From f90b29e01bbb1de056997af85847ab6344e4ed43 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 15 Jan 2025 16:59:39 -0800 Subject: [PATCH 001/115] [float8nocompile] support option to not precompute fp8 tensor for backward (#1517) --- .../float8nocompile/float8nocompile_linear.py | 164 ++++++++++++++++-- .../float8nocompile_linear_utils.py | 6 +- .../float8nocompile/test/train_test.py | 9 +- 3 files changed, 157 insertions(+), 22 deletions(-) diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear.py b/torchao/prototype/float8nocompile/float8nocompile_linear.py index 75a843e8c6..7e0eb85022 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear.py @@ -16,6 +16,7 @@ ToFP8ColumnMajor, ToFP8ColumnMajorT, ToFP8RowAndColumnMajor, + ToFP8RowMajor, ToFP8RowMajorTAndNonT, ) from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import ( @@ -36,32 +37,31 @@ def __init__(self, *args, **kwargs): Additional arguments on top of `torch.nn.Linear`'s arguments: * `config`: Float8LinearConfig """ - config = kwargs.pop("config") - kernel_algo = kwargs.pop("kernel_algo") - emulate = config.emulate + self.config = kwargs.pop("config") + self.kernel_algo = kwargs.pop("kernel_algo") + self.no_precompute_for_backward = kwargs.pop( + "no_precompute_for_backward", False + ) super().__init__(*args, **kwargs) - self.config = config - self.kernel_algo = kernel_algo - self.linear_mm_config = LinearMMConfig( # output ScaledMMConfig( - emulate, + self.config.emulate, self.config.gemm_config_output.use_fast_accum, False, self.config.pad_inner_dim, ), # grad_input ScaledMMConfig( - emulate, + self.config.emulate, self.config.gemm_config_grad_input.use_fast_accum, False, self.config.pad_inner_dim, ), # grad_weight ScaledMMConfig( - emulate, + self.config.emulate, self.config.gemm_config_grad_weight.use_fast_accum, False, self.config.pad_inner_dim, @@ -69,14 +69,22 @@ def __init__(self, *args, **kwargs): ) def forward(self, input: torch.Tensor) -> torch.Tensor: - # TODO(danielvegamyhre): support for FSDP once dependencies are implemented - output = matmul_with_args_in_hp.apply( - input, - self.weight, - self.config, - self.linear_mm_config, - self.kernel_algo, - ) + if self.no_precompute_for_backward: + output = matmul_with_args_in_hp_no_precompute_for_backward.apply( + input, + self.weight, + self.config, + self.linear_mm_config, + self.kernel_algo, + ) + else: + output = matmul_with_args_in_hp.apply( + input, + self.weight, + self.config, + self.linear_mm_config, + self.kernel_algo, + ) return output @classmethod @@ -85,6 +93,7 @@ def from_float( mod, config: Float8LinearConfig, # only default config is supported, non-defaults silently ignored kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, + no_precompute_for_backward: bool = False, ): """ Create an nn.Linear with fp8 compute from a regular nn.Linear @@ -101,6 +110,7 @@ def from_float( bias=False, config=config, kernel_algo=kernel_algo, + no_precompute_for_backward=no_precompute_for_backward, ) new_mod.weight = mod.weight new_mod.bias = mod.bias @@ -110,8 +120,20 @@ def from_float( class matmul_with_args_in_hp(torch.autograd.Function): + """FP8 matmul with args in high precision to be used in a region without AC. + FP8 tensors only needed for backward are computed as part of kernels in the forward pass, + to reduce number of kernel dispatches and increase throughput, at the cost of higher + peak memory usage.""" + @staticmethod - def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo): + def forward( + ctx, + input_hp: torch.Tensor, + weight_hp: torch.Tensor, + config: Float8LinearConfig, + linear_mm_config: LinearMMConfig, + kernel_algo: KernelAlgorithm, + ): # reshape to be 2D for triton kernels orig_input_shape = input_hp.shape input_hp = input_hp.reshape(-1, input_hp.shape[-1]) @@ -138,6 +160,7 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo): ctx.config = config ctx.linear_mm_config = linear_mm_config ctx.kernel_algo = kernel_algo + ctx.no_precompute_for_backward = False # reshape back to expected dims output = output.reshape(*orig_input_shape[:-1], output.shape[-1]) @@ -178,15 +201,118 @@ def backward(ctx, grad_output): ) grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major) + # reshape grad input to match original shape + grad_input = grad_input.reshape( + *orig_grad_output_shape[:-1], grad_input.shape[-1] + ) + # grad_weight = grad_output_t @ input # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output` # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85 grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major) + # grad input shape + return grad_input, grad_weight, None, None, None, None + + +class matmul_with_args_in_hp_no_precompute_for_backward(torch.autograd.Function): + """FP8 matmul with args in high precision to be used in a region with AC. + FP8 tensors only needed for backward are only computed in the backward pass + when needed, to reduce peak memory usage.""" + + @staticmethod + def forward( + ctx, + input_hp: torch.Tensor, + weight_hp: torch.Tensor, + config: Float8LinearConfig, + linear_mm_config: LinearMMConfig, + kernel_algo: KernelAlgorithm, + ): + # reshape to be 2D for triton kernels + orig_input_shape = input_hp.shape + input_hp = input_hp.reshape(-1, input_hp.shape[-1]) + + # output = input @ weight_t + input_fp8_row_major = ToFP8RowMajor.apply( + input_hp, + config.cast_config_input.target_dtype, + linear_mm_config, + GemmInputRole.INPUT, + kernel_algo, + ) + weight_t_fp8_col_major = ToFP8ColumnMajorT.apply( + weight_hp, + config.cast_config_weight.target_dtype, + linear_mm_config, + GemmInputRole.WEIGHT, + kernel_algo, + ) + output = torch.mm(input_fp8_row_major, weight_t_fp8_col_major) + + # with AC we only will save the original hp input tensor and weight for backward, + # and do the necessary fp8 conversions during the backward pass. + ctx.save_for_backward(input_hp, weight_hp) + ctx.config = config + ctx.linear_mm_config = linear_mm_config + ctx.kernel_algo = kernel_algo + ctx.no_precompute_for_backward = True + + # reshape back to expected dims + output = output.reshape(*orig_input_shape[:-1], output.shape[-1]) + return output + + @staticmethod + def backward(ctx, grad_output): + # grad_output may not be contiguous in cases like: + # output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1" + # results in a non-contiguous tensor with stride (0,0). + if not grad_output.is_contiguous(): + grad_output = grad_output.contiguous() + + input_hp, weight_hp = ctx.saved_tensors + + # reshsape to be 2D for triton kernels + orig_grad_output_shape = grad_output.shape + grad_output = grad_output.reshape(-1, grad_output.shape[-1]) + + # cast grad output to float8_e5m2 for backward + grad_output_fp8_row_major, grad_output_t_row_major = ( + ToFP8RowMajorTAndNonT.apply( + grad_output, + ctx.config.cast_config_grad_output.target_dtype, + ctx.linear_mm_config, + GemmInputRole.GRAD_OUTPUT, + ctx.kernel_algo, + ) + ) + + # grad_input = grad_output @ weight + weight_fp8_col_major = ToFP8ColumnMajor.apply( + weight_hp, + ctx.config.cast_config_weight.target_dtype, + ctx.linear_mm_config, + GemmInputRole.WEIGHT, + ctx.kernel_algo, + ) + grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major) + # reshape grad input to match original shape grad_input = grad_input.reshape( *orig_grad_output_shape[:-1], grad_input.shape[-1] ) + # grad_weight = grad_output_t @ input + # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output` + # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85 + input_fp8_col_major = ToFP8ColumnMajor.apply( + input_hp, + ctx.config.cast_config_input.target_dtype, + ctx.linear_mm_config, + GemmInputRole.INPUT, + ctx.kernel_algo, + ) + grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major) + # grad input shape - return grad_input, grad_weight, None, None, None + return grad_input, grad_weight, None, None, None, None diff --git a/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py index 6739242f0d..7e121c559e 100644 --- a/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py +++ b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py @@ -27,6 +27,7 @@ def convert_to_float8_nocompile_training( config: Float8LinearConfig = None, module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, + no_precompute_for_backward: bool = False, ) -> nn.Module: """ Swaps `torch.nn.Linear` in `module` with `Float8LinearNoCompile`. @@ -45,7 +46,10 @@ def convert_to_float8_nocompile_training( config = Float8LinearConfig() from_float = lambda m: Float8LinearNoCompile.from_float( - m, config=config, kernel_algo=kernel_algo + m, + config=config, + kernel_algo=kernel_algo, + no_precompute_for_backward=no_precompute_for_backward, ) return swap_linear_layers( module, diff --git a/torchao/prototype/float8nocompile/test/train_test.py b/torchao/prototype/float8nocompile/test/train_test.py index 40fc2787cb..871a49219e 100644 --- a/torchao/prototype/float8nocompile/test/train_test.py +++ b/torchao/prototype/float8nocompile/test/train_test.py @@ -39,7 +39,10 @@ def model2(): @pytest.mark.parametrize( "input_shape", [(16, 32), (1, 16, 32), (2, 16, 32), (128, 8192, 32)] ) -def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, int]): +@pytest.mark.parametrize("no_precompute_for_backward", [True, False]) +def test_model_weights_and_gradients( + model1, model2, input_shape: tuple[int, int], no_precompute_for_backward: bool +): assert torch.cuda.is_available() device = torch.device("cuda") @@ -48,7 +51,9 @@ def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, int # compare production float8 linear conversion with no-compile version convert_to_float8_training(model2) - convert_to_float8_nocompile_training(model1) + convert_to_float8_nocompile_training( + model1, no_precompute_for_backward=no_precompute_for_backward + ) input_tensor = torch.randn( *input_shape, requires_grad=True, dtype=torch.bfloat16, device=device From 5e59b510b97d5a1cd08da59b1f6b2df6a1d8cdfd Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 15 Jan 2025 17:25:53 -0800 Subject: [PATCH 002/115] [float8nocompile] add e2e fsdp test (#1523) --- torchao/prototype/float8nocompile/.gitignore | 3 - .../float8nocompile/test/fsdp_test.py | 97 +++++++++++++++++++ 2 files changed, 97 insertions(+), 3 deletions(-) delete mode 100644 torchao/prototype/float8nocompile/.gitignore create mode 100644 torchao/prototype/float8nocompile/test/fsdp_test.py diff --git a/torchao/prototype/float8nocompile/.gitignore b/torchao/prototype/float8nocompile/.gitignore deleted file mode 100644 index 38e0f6f87e..0000000000 --- a/torchao/prototype/float8nocompile/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -kernels/autogen/ -test/activation_checkpoint_test.py -test/distributed_test.py diff --git a/torchao/prototype/float8nocompile/test/fsdp_test.py b/torchao/prototype/float8nocompile/test/fsdp_test.py new file mode 100644 index 0000000000..44c0b13b71 --- /dev/null +++ b/torchao/prototype/float8nocompile/test/fsdp_test.py @@ -0,0 +1,97 @@ +###################################################################### +# +# To run these unit tests, use the following command: +# +# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test/fsdp_test.py +# +####################################################################### +import os + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable.fsdp import fully_shard + +from torchao.float8.float8_linear_utils import convert_to_float8_training +from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( + convert_to_float8_nocompile_training, +) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +if not TORCH_VERSION_AT_LEAST_2_5: + raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") + + +class TestModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(2048, 4096, bias=False), + nn.Linear(4096, 16, bias=False), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.layers(x) + + +def setup_distributed(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +@pytest.fixture +def model1(): + torch.manual_seed(0) + return TestModel() + + +@pytest.fixture +def model2(): + torch.manual_seed(0) + return TestModel() + + +def test_model_weights_and_gradients(model1, model2): + assert torch.cuda.is_available() + device = torch.device("cuda") + + setup_distributed() + + model1 = model1.to(torch.bfloat16).to(device) + model2 = model2.to(torch.bfloat16).to(device) + + # compare production float8 linear conversion with no-compile version + convert_to_float8_training(model2) + convert_to_float8_nocompile_training(model1) + + # distributed training with FSDP2 + fully_shard(model1) + fully_shard(model2) + + input_tensor = torch.randn( + 16, 2048, requires_grad=True, dtype=torch.bfloat16, device=device + ) + input_copy1 = input_tensor.clone().detach().requires_grad_(True) + input_copy2 = input_tensor.clone().detach().requires_grad_(True) + + loss_fn = nn.MSELoss() + + output1 = model1(input_copy1) + output2 = model2(input_copy2) + + loss1 = loss_fn(output1, torch.zeros_like(output1)) + loss2 = loss_fn(output2, torch.zeros_like(output2)) + + loss1.backward() + loss2.backward() + + # compare the outputs, weight gradients, and input gradients + assert torch.allclose(output1, output2, atol=0, rtol=0) + assert torch.allclose(input_copy1.grad, input_copy2.grad, atol=0, rtol=0) + for param1, param2 in zip(model1.parameters(), model2.parameters()): + assert torch.equal(param1.grad, param2.grad) + + dist.destroy_process_group() From 522f5b854a278ee9e68e80bf8213e19c9da4e547 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 15 Jan 2025 17:41:49 -0800 Subject: [PATCH 003/115] [float8nocompile] add triton kernel which does fp8 conversion to col major and transpose in col major at once (#1566) --- .../kernels/fp8_dynamic_tensorwise.py | 162 +++++++++++++++++- .../kernels/fp8_dynamic_tensorwise_test.py | 76 ++++++++ 2 files changed, 236 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py index 630e80e094..3786b52eb5 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py @@ -250,8 +250,8 @@ def to_fp8_col_major_t( block_col_offs[:, None] * output_stride_row + block_row_offs[None, :] * output_stride_col ) - out_mask = (block_row_offs[:, None] < output_num_rows) & ( - block_col_offs[None, :] < output_num_cols + out_mask = (block_col_offs[:, None] < output_num_rows) & ( + block_row_offs[None, :] < output_num_cols ) tl.store(out_ptr + out_offs, fp8_vals, mask=out_mask) @@ -381,6 +381,77 @@ def _to_fp8_row_major_t_and_non_t( tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask) +@triton.autotune(configs=kernel_configs_2D, key=["num_elements"]) +@triton.jit +def _to_fp8_col_major_t_and_non_t( + input_ptr, + col_major_out_ptr, + col_major_t_out_ptr, + scale_ptr, + num_elements: int, + fp8_dtype_min: float, + fp8_dtype_max: float, + input_num_rows: int, + input_num_cols: int, + input_stride_row: int, + input_stride_col: int, + col_major_out_stride_row: int, + col_major_out_stride_col: int, + col_major_t_out_stride_row: int, + col_major_t_out_stride_col: int, + input_dtype: tl.constexpr, + output_dtype: tl.constexpr, + BLOCK_SIZE_ROWS: tl.constexpr, + BLOCK_SIZE_COLS: tl.constexpr, + EPS: tl.constexpr, +): + """ + Reads a row-major, high precision input tensor and writes 2 output tensors: + 1) fp8 col major tensor (transposed) + 2) fp8 col major tensor + """ + # col major tranposed + block_row_id = tl.program_id(axis=0) + block_col_id = tl.program_id(axis=1) + + # load scaling factor + scale = tl.load(scale_ptr).to(tl.float32) + + # load block of input tensor + block_row_start = block_row_id * BLOCK_SIZE_ROWS + block_col_start = block_col_id * BLOCK_SIZE_COLS + block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS) + block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS) + input_offs = ( + block_row_offs[:, None] * input_stride_row + + block_col_offs[None, :] * input_stride_col + ) + mask = (block_row_offs[:, None] < input_num_rows) & ( + block_col_offs[None, :] < input_num_cols + ) + vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype) + + # perform conversion + vals = vals * scale + fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype) + + # 1. write col-major output + out_offs = block_row_offs[:, None] + block_col_offs[None, :] * input_num_rows + tl.store(col_major_out_ptr + out_offs, fp8_vals, mask=mask) + + # 2. write tranposed col-major output + col_major_t_num_rows = input_num_cols + col_major_t_num_cols = input_num_rows + out_offs = ( + block_col_offs[:, None] * col_major_t_out_stride_row + + block_row_offs[None, :] * col_major_t_out_stride_col + ) + out_mask = (block_col_offs[:, None] < col_major_t_num_rows) & ( + block_row_offs[None, :] < col_major_t_num_cols + ) + tl.store(col_major_t_out_ptr + out_offs, fp8_vals.trans(1, 0), mask=out_mask) + + @triton.autotune(configs=kernel_configs_1D, key=["num_elements"]) @triton.jit def _amax_atomic( @@ -859,6 +930,93 @@ def hp_to_fp8_row_major_t_and_non_t( return fp8_tensor_row_major, fp8_tensor_row_major_t +def hp_to_fp8_col_major_t_and_non_t( + hp_tensor: torch.Tensor, + fp8_dtype: torch.dtype, + linear_mm_config: LinearMMConfig, + gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX, +) -> Float8Tensor: + assert hp_tensor.is_contiguous(), "input tensor must be contiguous" + + tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype] + tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype] + + fp8_dtype_min = torch.finfo(fp8_dtype).min + fp8_dtype_max = torch.finfo(fp8_dtype).max + + # compute scaling factor for tensor + scale = _hp_tensor_to_scale( + hp_tensor, + tl_input_dtype, + fp8_dtype_max, + algo, + ) + + # perform fp8 conversion + input_num_rows, input_num_cols = hp_tensor.shape + num_elements = hp_tensor.numel() + + # preallocate necessary output tensors + fp8_output_col_major = torch.empty( + (input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device + ) + fp8_output_col_major_t = torch.empty_like( + hp_tensor.t(), + dtype=fp8_dtype, + device=hp_tensor.device, + ) + + # launch triton kernel to perform conversion + grid = lambda meta: ( + triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]), + triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]), + ) + _to_fp8_col_major_t_and_non_t[grid]( + hp_tensor, + fp8_output_col_major, + fp8_output_col_major_t, + scale, + num_elements, + fp8_dtype_min, + fp8_dtype_max, + input_num_rows, + input_num_cols, + hp_tensor.stride(0), + hp_tensor.stride(1), + fp8_output_col_major.stride(0), + fp8_output_col_major.stride(1), + fp8_output_col_major_t.stride(0), + fp8_output_col_major_t.stride(1), + input_dtype=tl_input_dtype, + output_dtype=tl_output_dtype, + EPS=EPS, + ) + + # for col major we need to update the strides to reflect the new memory layout + col_major_strides = (1, input_num_rows) + fp8_output_col_major = fp8_output_col_major.as_strided( + fp8_output_col_major.size(), col_major_strides + ) + + # wrap outputs in Float8Tensors + fp8_tensor_col_major = Float8Tensor( + fp8_output_col_major, + scale, + orig_dtype=hp_tensor.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) + fp8_tensor_col_major_t = Float8Tensor( + fp8_output_col_major_t, + scale, + orig_dtype=hp_tensor.dtype, + linear_mm_config=linear_mm_config, + gemm_input_role=gemm_input_role, + ) + return fp8_tensor_col_major, fp8_tensor_col_major_t + + def _hp_tensor_to_scale( hp_tensor: torch.Tensor, tl_input_dtype: tl.core.dtype, diff --git a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py index f0dd78bc01..55a3fecd79 100644 --- a/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py +++ b/torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py @@ -8,6 +8,7 @@ KernelAlgorithm, hp_to_fp8_col_major, hp_to_fp8_col_major_t, + hp_to_fp8_col_major_t_and_non_t, hp_to_fp8_row_and_col_major, hp_to_fp8_row_major, hp_to_fp8_row_major_t, @@ -410,3 +411,78 @@ def test_fp8_hp_to_fp8_row_major_t_and_non_t( torch.float8_e4m3fn, LinearMMConfig(), ) + + +@pytest.mark.parametrize( + "algo", + [KernelAlgorithm.REDUCTION, KernelAlgorithm.ATOMIC_MAX], +) +@pytest.mark.parametrize( + "input_shape", + [(2, 4), (32, 16), (512, 512)], +) +def test_fp8_hp_to_fp8_col_major_t_and_non_t( + input_shape: tuple[int, int], algo: KernelAlgorithm +): + assert torch.cuda.is_available() + device = "cuda" + input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + x_bf16 = input_bf16.clone().detach().to(device) + y_bf16 = input_bf16.clone().detach().to(device) + + # production implementation + x_fp8_row_major = hp_tensor_to_float8_dynamic( + x_bf16, + torch.float8_e4m3fn, + LinearMMConfig(), + ) + x_fp8_col_major = x_fp8_row_major.t().contiguous().t() + x_fp8_col_major_t = x_fp8_row_major.t() + + # float8nocompile triton implementation + y_fp8_col_major, y_fp8_col_major_t = hp_to_fp8_col_major_t_and_non_t( + y_bf16, + torch.float8_e4m3fn, + LinearMMConfig(), + algo=algo, + ) + + # check scales + assert torch.eq(x_fp8_col_major._scale, y_fp8_col_major._scale) + assert torch.eq(x_fp8_col_major_t._scale, y_fp8_col_major_t._scale) + + # check data + assert torch.all(torch.eq(x_fp8_col_major._data, y_fp8_col_major._data)) + assert torch.all(torch.eq(x_fp8_col_major_t._data, y_fp8_col_major_t._data)) + + # check shapes + assert x_fp8_col_major.shape == y_fp8_col_major.shape + assert x_fp8_col_major_t.shape == y_fp8_col_major_t.shape + + # check strides + assert x_fp8_col_major.stride() == y_fp8_col_major.stride() + assert x_fp8_col_major_t.stride() == y_fp8_col_major_t.stride() + + # check memory layout + assert not is_row_major(x_fp8_col_major.stride()) + assert not is_row_major(y_fp8_col_major.stride()) + assert not is_row_major(x_fp8_col_major_t.stride()) + assert not is_row_major(y_fp8_col_major_t.stride()) + + # check underlying memory layout + assert ( + x_fp8_col_major._data.storage().tolist() + == y_fp8_col_major._data.storage().tolist() + ) + assert ( + x_fp8_col_major_t._data.storage().tolist() + == y_fp8_col_major_t._data.storage().tolist() + ) + + # assert that error is raised when input tensor is not contiguous + with pytest.raises(AssertionError, match="tensor must be contiguous"): + hp_to_fp8_col_major_t_and_non_t( + y_bf16.t(), # transpose so tensor memory layout is no longer contiguous + torch.float8_e4m3fn, + LinearMMConfig(), + ) From 74a15f1dd72839264eb87adfaf986cdfcc9d6781 Mon Sep 17 00:00:00 2001 From: y-sq <58683402+y-sq@users.noreply.github.com> Date: Thu, 16 Jan 2025 10:06:03 -0800 Subject: [PATCH 004/115] Add a register_replacement to fix float8 delayed scaling kernel fusion issues in torchao/float8 Differential Revision: D67758184 Pull Request resolved: https://github.com/pytorch/ao/pull/1469 --- benchmarks/float8/profile_linear_float8.py | 10 +- test/float8/test_compile.py | 68 +++++++++++ torchao/float8/README.md | 5 +- torchao/float8/__init__.py | 4 + torchao/float8/inductor_utils.py | 126 +++++++++++++++++++++ 5 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 torchao/float8/inductor_utils.py diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 19fb492c32..5045956954 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -37,6 +37,7 @@ update_triton_kernels_in_prof_chome_trace_with_torch_logs, ) +from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( Float8LinearRecipeName, ScalingType, @@ -206,7 +207,7 @@ def profile_function( # by default torch.compile appends to log_file_name, so we delete it # if it exists if os.path.isfile(config.logs_file_path): - pathlib.Path.unlink(config.logs_file_path) + pathlib.Path(config.logs_file_path).unlink() torch._logging._init_logs(log_file_name=config.logs_file_path) activities = [ProfilerActivity.CPU] @@ -288,6 +289,7 @@ def main( add_inductor_metadata_to_trace: bool = True, enable_sync_amax_history: bool = True, enable_activation_checkpointing: bool = False, + enable_float8_delayed_scaling_inductor_passes: bool = False, ): assert model_type in ( "linear", @@ -325,6 +327,12 @@ def main( print( f"enable_activation_checkpointing is set to {enable_activation_checkpointing}" ) + print( + f"enable_float8_delayed_scaling_inductor_passes is set to {enable_float8_delayed_scaling_inductor_passes}" + ) + + if enable_float8_delayed_scaling_inductor_passes: + _prototype_register_float8_delayed_scaling_inductor_passes() device = "cuda" ref_dtype = torch.bfloat16 diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 32d6bdfbbd..c42ab8ee77 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -7,6 +7,7 @@ import random import sys import unittest +from dataclasses import replace from io import StringIO import pytest @@ -25,6 +26,7 @@ from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend +from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -51,6 +53,7 @@ from torchao.float8.float8_utils import config_has_stateful_scaling from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.utils import is_fbcode def _test_compile_base( @@ -465,5 +468,70 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): assert torch.equal(float8_eager._data, float8_compile._data) +@unittest.skipIf( + not is_sm_at_least_89() or not is_fbcode(), + "CUDA with float8 support not available; or not on fbcode (the test needs be run with the latest pytorch package)", +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_delayed_scaling_pattern_replacement(dtype: torch.dtype): + from torch._inductor import config as inductor_config + from torch._inductor import metrics + + inductor_config.loop_ordering_after_fusion = True + + def clear_all(): + metrics.reset() + from torch._inductor.fx_passes.post_grad import ( + pass_patterns as post_grad_patterns_all, + ) + + post_grad_patterns_all[1].clear() + post_grad_patterns_all[1].seen_patterns.clear() + + def compile_and_run_single_layer(): + random.seed(0) + torch.manual_seed(0) + x_shape = (2048, 3072) + linear_dtype = dtype + + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() + m_ref = nn.Linear(3072, 2048, bias=True, device="cuda", dtype=linear_dtype) + + config = get_test_float8_linear_config( + ScalingType.DELAYED, + ScalingType.DELAYED, + ScalingType.DELAYED, + False, + ) + + config = replace(config, enable_amax_init=False) + + m_fp8 = StatefulFloat8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) + + m_fp8 = torch.compile(m_fp8, backend="inductor", fullgraph=True) + m_ref = torch.compile(m_ref, backend="inductor", fullgraph=True) + + y_fp8 = m_fp8(x) + y_fp8.sum().backward() + + return m_fp8.weight.grad + + clear_all() + ref_output = compile_and_run_single_layer() + ref_count_kernel = metrics.generated_kernel_count + + clear_all() + _prototype_register_float8_delayed_scaling_inductor_passes() + new_output = compile_and_run_single_layer() + new_count_kernel = metrics.generated_kernel_count + + torch.equal(ref_output, new_output) + # With the pattern replacement workaround, amax reduction kernels for the 3 tensors (weight, activation, gradient) are fused. + assert ref_count_kernel == new_count_kernel + 3 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchao/float8/README.md b/torchao/float8/README.md index 1a87770899..8487096e6c 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -82,6 +82,9 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if not TORCH_VERSION_AT_LEAST_2_5: raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") +# Recommended: enable additional torchinductor passes to improve the performance of delayed scaling +torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() + # create model and sample input m = nn.Sequential( nn.Linear(2048, 4096), @@ -172,7 +175,7 @@ For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium ## Scaling type vs speedup -Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling, so the observed performance of delayed scaling is close to that of dynamic scaling. As the torch.compile limitations are fixed, we expect delayed scaling to eventually become more performant compared to dynamic scaling. +Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling without workarounds. We have a prototype workaround (API subject to change) with the `torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()` API to improve delayed scaling performance. ## torch.compile behavior vs speedup diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 3336330361..258db53be0 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -23,6 +23,9 @@ ScaledMMConfig, ) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp +from torchao.float8.inductor_utils import ( + _prototype_register_float8_delayed_scaling_inductor_passes, +) from torchao.float8.inference import Float8MMConfig from torchao.float8.stateful_float8_linear import WeightWithDelayedFloat8CastTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -54,5 +57,6 @@ "linear_requires_sync", "sync_float8_amax_and_scale_history", "precompute_float8_dynamic_scale_for_fsdp", + "_prototype_register_float8_delayed_scaling_inductor_passes", # note: Float8Tensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/inductor_utils.py b/torchao/float8/inductor_utils.py new file mode 100644 index 0000000000..3e86202536 --- /dev/null +++ b/torchao/float8/inductor_utils.py @@ -0,0 +1,126 @@ +import functools +import inspect +import traceback +from collections import deque + +import torch + + +def amax_with_scaling_pattern(tensor_x_inp, scale_x, fp8_dtype, fp8_max): + tensor_x = tensor_x_inp.to(torch.float32) * scale_x + tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) + tensor_x = tensor_x.to(fp8_dtype) + amax = torch.max(torch.abs(tensor_x_inp)) + return (tensor_x, amax) + + +def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, fp8_dtype, fp8_max): + tensor_x = tensor_x_inp.to(torch.float32) * scale_x + tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) + tensor_x = tensor_x.to(fp8_dtype) + amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values + amax = torch.max(amax_1) + return (tensor_x, amax) + + +# The amax_with_scaling_pattern will also match dynamic scaling cases, we want to avoid that. +# `scale_x` of delayed scaling comes from the previous iteration, instead of from `tensor_x_inp`. +# We check that `scale_x` is not a dependency of `tensor_x_inp` +def fp8_delayed_scaling_extra_check(match): + scale_x_inputs = deque([match.kwargs["scale_x"]]) + max_num_node_to_check = 20 # Don't traverse too many nodes + current_num_node = 0 + while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check: + current_node = scale_x_inputs.popleft() + for n in current_node.all_input_nodes: + if n == match.kwargs["tensor_x_inp"]: + return False + scale_x_inputs.append(n) + current_num_node += 1 + return True + + +def partialize_and_update_signature(func, **kwargs): + """ + Equivalent to functools.partial but also updates the signature on returned function + """ + original_sig = inspect.signature(func) + parameters = original_sig.parameters + + new_parameters = { + key: value for key, value in parameters.items() if key not in kwargs + } + new_sig = inspect.Signature(parameters=list(new_parameters.values())) + + partial_func = functools.partial(func, **kwargs) + + def wrapper(*args, **kwargs): + return partial_func(*args, **kwargs) + + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + wrapper.__name__ = func.__name__ + + return wrapper + + +def register_fp8_delayed_scaling_patterns_inner(): + from torch._inductor.fx_passes.post_grad import ( + pass_patterns as post_grad_patterns_all, + ) + from torch._inductor.pattern_matcher import fwd_only, register_replacement + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + for fp8_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ]: + # torch.float16 has the same pattern as torch.bfloat16, because they both needs `tensor_x_inp.to(torch.float32)` + for dtype in [torch.float32, torch.bfloat16]: + device = "cuda" + register_replacement( + partialize_and_update_signature( + amax_with_scaling_pattern, + fp8_dtype=fp8_dtype, + fp8_max=torch.finfo(fp8_dtype).max, + ), + partialize_and_update_signature( + amax_with_scaling_tiled_replacement, + fp8_dtype=fp8_dtype, + fp8_max=torch.finfo(fp8_dtype).max, + ), + [ + torch.tensor((16, 16), device=device, dtype=dtype), + torch.tensor(2.0, device=device, dtype=torch.float32), + ], + fwd_only, + post_grad_patterns, + extra_check=fp8_delayed_scaling_extra_check, + ) + + +""" +This a short-term workaround of the delayed scaling performance issue. +It explicitly replaces `max(x)` with `max(max(x, dim=-1))`, enabling the fusion of amax scaling factor calculation and fp8 casting. + +Usage: + To use this solution, add the following line at the beginning of your user code: + torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() +""" + + +def _prototype_register_float8_delayed_scaling_inductor_passes() -> None: + # To make the fp8 delayed scaling pattern work, we need a fix pr from inductor, https://github.com/pytorch/pytorch/pull/139321 + # Will throw the error if the pattern registration did not work, up to user to decide what to do with it + try: + register_fp8_delayed_scaling_patterns_inner() + except AssertionError as e: + if "assert pattern_repr not in _seen_patterns" in traceback.format_exc(): + print( + f"Caught duplicated patterns in register_fp8_delayed_scaling_patterns: {traceback.format_exc()}", + "\nPlease update your pytorch dependency to the latest main branch to fix it.\n", + ) + raise e From eea4d25adebd6f84c0ebe6aa92d706396855488f Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 16 Jan 2025 11:34:28 -0800 Subject: [PATCH 005/115] Update version to 0.9.0 (#1568) Update verion to 0.9.0 --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index a3df0a6959..ac39a106c4 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.8.0 +0.9.0 From f520c917abc38aeaf57bac22a870f0479450f62d Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 16 Jan 2025 17:05:39 -0800 Subject: [PATCH 006/115] Update supported dtypes for fp8 (#1573) --- torchao/quantization/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 0a3ab7bcec..ace4d8c14c 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -156,7 +156,7 @@ from torchao.quantization import quantize_, float8_weight_only quantize_(model, float8_weight_only()) ``` -This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. +Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. #### A8W8 Float8 Dynamic Quantization with Tensorwise Scaling @@ -166,7 +166,7 @@ from torchao.quantization import quantize_, float8_dynamic_activation_float8_wei quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor())) ``` -This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. +Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. ### A8W8 Float8 Dynamic Quantization with Rowwise Scaling @@ -176,7 +176,7 @@ from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_fl quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow())) ``` -This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. +Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. #### A16W6 Floating Point WeightOnly Quantization From cf453360dd3e09394657172f8e8d8da23cfbf043 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 16 Jan 2025 17:11:45 -0800 Subject: [PATCH 007/115] Relax dtype requirements for int4 and float8 quants in autoquant (#1571) * Relax dtype requirements for int4 quants in autoquant Summary: Some of the int4 quant only works with bfloat16/float16, previously we require the model to be in correct dtype to apply these in autoquant, this PR relaxes the constraints by converting weight and activation to compatible dtypes Test Plan: python test/integration/test_integration.py -k test_autoquant_int4wo Reviewers: Subscribers: Tasks: Tags: * remove prints * add float8 * run pre-commit * run pre-commit * manual format * enable bias=True test * remove print --- test/integration/test_integration.py | 125 ++++++++++++---- torchao/dtypes/uintx/marlin_sparse_layout.py | 5 + torchao/quantization/autoquant.py | 146 +++++++++++++------ 3 files changed, 207 insertions(+), 69 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index bcd8af7ad3..1087db8cf8 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -25,6 +25,9 @@ AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, AQFloat8WeightOnlyQuantizedLinearWeight, + AQGemliteInt4G64WeightOnlyQuantizedLinearWeight, + AQInt4G32WeightOnlyQuantizedLinearWeight, + AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight, AQInt8DynamicallyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight, AQInt8WeightOnlyQuantizedLinearWeight2, @@ -1751,37 +1754,109 @@ def test_autoquant_min_sqnr(self, device, dtype): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+." ) - def test_autoquant_float(self): + def test_autoquant_hp_float(self): device = "cuda" dtype = torch.float32 m, k, n = 128, 128, 128 example_input = torch.randn(m, k, device=device, dtype=dtype) - model = ( - torch.nn.Sequential( - torch.nn.ReLU(), - torch.nn.Linear(k, n), - torch.nn.ReLU(), + for qclass in torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST: + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n, bias=True), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) ) - .to(device) - .to(dtype) - ) - ref = model(example_input) - torchao.autoquant( - model, - qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, - ) - out = model(example_input) - from torchao.quantization.autoquant import ( - BFloat16Tensor, - Float16Tensor, - Float32Tensor, - ) + ref = model(example_input) + qtensor_class_list = [qclass] + torchao.autoquant( + model, + qtensor_class_list=qtensor_class_list, + ) + out = model(example_input) + self.assertIn( + type(model[1].weight), + qtensor_class_list, + ) + self.assertGreater(compute_error(out, ref), 40) - self.assertIn( - type(model[1].weight), [Float32Tensor, Float16Tensor, BFloat16Tensor] - ) - print(compute_error(out, ref)) - self.assertGreater(compute_error(out, ref), 60) + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." + ) + @unittest.skipIf(not has_gemlite, "gemlite not available") + def test_autoquant_int4wo(self, device, dtype): + if device == "cpu": + self.skipTest(f"int4wo is for cuda, not {device}") + + m, k, n = 128, 128, 128 + example_input = torch.randn(m, k, device=device, dtype=dtype) + + for qclass in [ + AQGemliteInt4G64WeightOnlyQuantizedLinearWeight, + AQInt4G32WeightOnlyQuantizedLinearWeight, + AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight, + ]: + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n, bias=True), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) + ref = model(example_input) + qtensor_class_list = [qclass] + torchao.autoquant( + model, + qtensor_class_list=qtensor_class_list, + ) + out = model(example_input) + + self.assertIn(type(model[1].weight), qtensor_class_list) + self.assertGreater(compute_error(ref, out), 20) + + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." + ) + def test_autoquant_float8(self, device, dtype): + if device == "cpu": + self.skipTest(f"int4wo is for cuda, not {device}") + + # note: marlin sparse layout failed when scale_t has a dimension of 1d + m, k, n = 128, 128, 128 + example_input = torch.randn(m, k, device=device, dtype=dtype) + + for qclass in [ + AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, + AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, + AQFloat8WeightOnlyQuantizedLinearWeight, + ]: + model = ( + torch.nn.Sequential( + torch.nn.ReLU(), + torch.nn.Linear(k, n, bias=True), + torch.nn.ReLU(), + ) + .to(device) + .to(dtype) + ) + ref = model(example_input) + qtensor_class_list = [qclass] + torchao.autoquant( + model, + qtensor_class_list=qtensor_class_list, + ) + out = model(example_input) + + self.assertIn(type(model[1].weight), qtensor_class_list) + self.assertGreater(compute_error(ref, out), 20) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.") diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index e37623182a..2a84dd1813 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -227,6 +227,11 @@ def from_plain( # Linear layers are (in_features, out_features) but the int_data that is reaching this point # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. q_w_24 = int_data.t() + # addressing the case when scale has dimension 1, happens when + # weight_shape[-1] == group_size == 128 + if scale.ndim == 1: + scale = scale.reshape(scale.shape[0], -1) + scale_t = scale.t() if not torch.cuda.get_device_capability()[0] >= 8: diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index d506d2b65e..d49e84e066 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -16,6 +16,7 @@ from torchao.kernel import safe_int_mm from torchao.quantization.linear_activation_quantized_tensor import ( LinearActivationQuantizedTensor, + to_linear_activation_quantized, ) from torchao.quantization.quant_primitives import ( MappingType, @@ -370,6 +371,18 @@ def _is_interpolate_mode(mode): return False +def _to_float16(x: torch.Tensor) -> torch.Tensor: + return x.to(torch.float16) + + +def _to_bfloat16(x: torch.Tensor) -> torch.Tensor: + return x.to(torch.bfloat16) + + +def _identity(x: torch.Tensor) -> torch.Tensor: + return x + + class AQMixin: """ Tests and benchmarks the autoquantization process for the given activation matrix, weight, and bias. @@ -610,9 +623,11 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): return y -class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): +class AQInt4G32WeightOnlyQuantizedLinearWeight( + LinearActivationQuantizedTensor, AQMixin +): """ - AutoQuantizable version of Int4WeightOnlyQuantizedLinearWeight + AutoQuantizable version of int4_weight_only """ group_size: int = 32 @@ -621,20 +636,30 @@ class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): @classmethod def from_float(cls, weight): + from torchao.dtypes import to_affine_quantized_intx + group_size = cls.group_size _layout = cls.aq_layout if weight.shape[-1] % group_size != 0: return weight + input_quant_func = None + + # NOTE: we only convert activation dtype and weight dtype here + # because the kernel implementation for both TensorCoreTiledLayout and MarlinSparseLayout + # can work with multiple bias dtypes (by converting bias to the dtype of activation) if ( isinstance(_layout, TensorCoreTiledLayout) and weight.dtype != torch.bfloat16 ): - return weight - - if isinstance(_layout, MarlinSparseLayout) and weight.dtype != torch.float16: - return weight + weight = weight.to(torch.bfloat16) + input_quant_func = _to_bfloat16 + elif isinstance(_layout, MarlinSparseLayout) and weight.dtype != torch.float16: + weight = weight.to(torch.float16) + input_quant_func = _to_float16 + else: + input_quant_func = _identity use_hqq = True mapping_type = MappingType.ASYMMETRIC @@ -653,7 +678,7 @@ def from_float(cls, weight): zero_point_domain = ZeroPointDomain.INT use_hqq = False - return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_hp_to_intx( + weight = to_affine_quantized_intx( weight, mapping_type, block_size, @@ -668,6 +693,10 @@ def from_float(cls, weight): use_hqq=use_hqq, ) + return super(AQInt4G32WeightOnlyQuantizedLinearWeight, cls).from_float( + weight, input_quant_func + ) + class AQInt4G64WeightOnlyQuantizedLinearWeight( AQInt4G32WeightOnlyQuantizedLinearWeight @@ -694,16 +723,19 @@ class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight( aq_layout: Layout = MarlinSparseLayout() -class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin): +class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight( + LinearActivationQuantizedTensor, AQMixin +): group_size: int = 32 @classmethod def from_float(cls, weight): - if weight.dtype != torch.float16: - return weight - + from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs + if weight.dtype != torch.float16: + weight = weight.to(torch.float16) + bit_width = 4 packing_bitwidth = 32 contiguous = None @@ -711,9 +743,12 @@ def from_float(cls, weight): aqt_kwargs = get_gemlite_aqt_kwargs( weight, cls.group_size, bit_width, packing_bitwidth, contiguous, use_hqq ) - return super( - AQGemliteInt4G32WeightOnlyQuantizedLinearWeight, cls - ).from_hp_to_intx(weight, **aqt_kwargs) + weight = to_affine_quantized_intx(weight, **aqt_kwargs) + input_quant_func = _to_float16 + + return super(AQGemliteInt4G32WeightOnlyQuantizedLinearWeight, cls).from_float( + weight, input_quant_func + ) class AQGemliteInt4G64WeightOnlyQuantizedLinearWeight( @@ -755,11 +790,24 @@ def from_float(cls, weight): return weight +# TODO: remove skip_weight_conversion arg class Float32Tensor(TorchAOBaseTensor): """Tensor subclass tensor for fp32 dtype""" - def __init__(self, weight): - self.weight = weight.to(torch.float32) + @staticmethod + def __new__(cls, weight, skip_weight_conversion=False): + kwargs = {} + kwargs["device"] = weight.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else weight.layout + ) + kwargs["dtype"] = weight.dtype + kwargs["requires_grad"] = False + shape = weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, weight, skip_weight_conversion=False): + self.weight = weight if skip_weight_conversion else weight.to(torch.float32) @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): @@ -778,7 +826,7 @@ def _apply_fn_to_data(self, fn): @classmethod def from_float(cls, weight): - return Float32Tensor(weight) + return cls(weight) @Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default]) @@ -816,8 +864,8 @@ def _(func, types, args, kwargs): class BFloat16Tensor(Float32Tensor): - def __init__(self, weight): - self.weight = weight.to(torch.bfloat16) + def __init__(self, weight, skip_weight_conversion=False): + self.weight = weight if skip_weight_conversion else weight.to(torch.bfloat16) @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): @@ -830,13 +878,13 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): ).to(dtype=orig_dtype) @classmethod - def from_float(cls, weight): - return BFloat16Tensor(weight) + def from_float(cls, weight, skip_weight_conversion=False): + return cls(weight, skip_weight_conversion) class Float16Tensor(Float32Tensor): - def __init__(self, weight): - self.weight = weight.to(torch.float16) + def __init__(self, weight, skip_weight_conversion=False): + self.weight = weight if skip_weight_conversion else weight.to(torch.float16) @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): @@ -849,8 +897,8 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): ).to(dtype=orig_dtype) @classmethod - def from_float(cls, weight): - return Float16Tensor(weight) + def from_float(cls, weight, skip_weight_conversion=False): + return cls(weight, skip_weight_conversion) class AQFloat32LinearWeight(Float32Tensor, AQMixin): @@ -911,9 +959,7 @@ def from_float(cls, weight): ) -class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight( - AQMixin, LinearActivationQuantizedTensor -): +class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, BFloat16Tensor): """ AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling """ @@ -942,12 +988,13 @@ def get_per_token_block_size(x): input_target_dtype = torch.float8_e4m3fn _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True)) # TODO: make this serializable - input_quant_func = lambda x: _input_activation_quant_func_fp8( - x=x, - activation_granularity=cls.activation_granularity, - activation_dtype=input_target_dtype, - ) + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": cls.activation_granularity, + "activation_dtype": input_target_dtype, + } block_size = get_weight_block_size(weight) + weight = to_affine_quantized_floatx( input_float=weight, block_size=block_size, @@ -955,10 +1002,15 @@ def get_per_token_block_size(x): _layout=_layout, scale_dtype=torch.float32, ) - weight = super( + weight = to_linear_activation_quantized( + weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) + # at inference time, + # we first convert the input, weight and bias to bfloat16, and then quantize activation + # and then dispatch to the quantized ops + return super( AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls - ).from_float(weight, input_quant_func) - return weight + ).from_float(weight, skip_weight_conversion=True) class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight( @@ -982,15 +1034,14 @@ def get_weight_block_size(x): return x.shape target_dtype = torch.float8_e4m3fn - input_target_dtype = torch.float8_e4m3fn _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True)) - # TODO: make this serializable - input_quant_func = lambda x: _input_activation_quant_func_fp8( - x=x, - activation_granularity=cls.activation_granularity, - activation_dtype=input_target_dtype, - ) + # TODO: test serializable + input_quant_func = _input_activation_quant_func_fp8 + input_quant_args = { + "activation_granularity": cls.activation_granularity, + "activation_dtype": input_target_dtype, + } block_size = get_weight_block_size(weight) weight = to_affine_quantized_floatx( input_float=weight, @@ -1001,7 +1052,7 @@ def get_weight_block_size(x): ) weight = super( AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls - ).from_float(weight, input_quant_func) + ).from_float(weight, input_quant_func, input_quant_args) return weight @@ -1299,3 +1350,10 @@ def finalize_autoquant(): if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) + torch.serialization.add_safe_globals( + [ + _to_float16, + _to_bfloat16, + _identity, + ] + ) From d96c6a79adcf1f4fa127b0cd7f762921bb951c8a Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 17 Jan 2025 08:35:50 -0800 Subject: [PATCH 008/115] Enable ROCM in CI (#999) * Enable ROCM in CI --------- Co-authored-by: amdfaa <107946068+amdfaa@users.noreply.github.com> --- .github/workflows/regression_test.yml | 13 ++++++++++--- torchao/utils.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 74b39d2ef2..eaf2e3cbbb 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -17,6 +17,10 @@ concurrency: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} +permissions: + id-token: write + contents: read + jobs: test-nightly: strategy: @@ -33,10 +37,16 @@ jobs: torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" + - name: ROCM Nightly + runs-on: linux.rocm.gpu.2 + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3' + gpu-arch-type: "rocm" + gpu-arch-version: "6.3" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 120 + no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }} runner: ${{ matrix.runs-on }} gpu-arch-type: ${{ matrix.gpu-arch-type }} gpu-arch-version: ${{ matrix.gpu-arch-version }} @@ -71,7 +81,6 @@ jobs: torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - - name: CPU 2.3 runs-on: linux.4xlarge torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu' @@ -99,8 +108,6 @@ jobs: conda create -n venv python=3.9 -y conda activate venv echo "::group::Install newer objcopy that supports --set-section-alignment" - yum install -y devtoolset-10-binutils - export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH python -m pip install --upgrade pip pip install ${{ matrix.torch-spec }} pip install -r dev-requirements.txt diff --git a/torchao/utils.py b/torchao/utils.py index 7a17c1b104..4729675a14 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -607,7 +607,7 @@ def _torch_version_at_least(min_version): def is_MI300(): if torch.cuda.is_available() and torch.version.hip: mxArchName = ["gfx940", "gfx941", "gfx942"] - archName = torch.cuda.get_device_properties().gcnArchName + archName = torch.cuda.get_device_properties(0).gcnArchName for arch in mxArchName: if arch in archName: return True From a1c67b98905e81e51d56c1558742ca7e0fff49c1 Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Fri, 17 Jan 2025 08:40:49 -0800 Subject: [PATCH 009/115] Skip Unit Tests for ROCm CI (#1563) * skip failing unit tests for ROCm CI * fix util import --- test/__init__.py | 0 test/dtypes/test_affine_quantized.py | 4 ++++ test/dtypes/test_floatx.py | 2 ++ test/float8/test_base.py | 3 +++ test/hqq/test_hqq_affine.py | 2 ++ test/integration/test_integration.py | 7 +++++++ test/kernel/test_galore_downproj.py | 2 ++ test/prototype/test_awq.py | 3 +++ test/prototype/test_low_bit_optim.py | 2 ++ test/prototype/test_splitk.py | 3 +++ test/quantization/test_galore_quant.py | 2 ++ test/quantization/test_marlin_qqq.py | 3 +++ test/sparsity/test_marlin.py | 4 +++- test/test_ops.py | 3 +++ test/test_s8s4_linear_cutlass.py | 3 +++ test/test_utils.py | 29 ++++++++++++++++++++++++++ 16 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 test/__init__.py diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index f08ba7aa72..88e133ccf8 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -2,6 +2,7 @@ import unittest import torch +from test_utils import skip_if_rocm from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( TestCase, @@ -89,6 +90,7 @@ def test_tensor_core_layout_transpose(self): aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) + @skip_if_rocm("ROCm development in progress") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize( "apply_quant", get_quantization_functions(True, True, "cuda", True) @@ -168,6 +170,7 @@ def apply_uint6_weight_only_quant(linear): deregister_aqt_quantized_linear_dispatch(dispatch_condition) + @skip_if_rocm("ROCm development in progress") @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): @@ -180,6 +183,7 @@ class TestAffineQuantizedBasic(TestCase): COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.bfloat16] + @skip_if_rocm("ROCm development in progress") @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_flatten_unflatten(self, device, dtype): diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 8bb39b2cc8..ea30edfe38 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -2,6 +2,7 @@ import unittest import torch +from test_utils import skip_if_rocm from torch.testing._internal.common_utils import ( TestCase, instantiate_parametrized_tests, @@ -108,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits): @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) @parametrize("dtype", [torch.half, torch.bfloat16]) + @skip_if_rocm("ROCm development in progress") @unittest.skipIf(is_fbcode(), reason="broken in fbcode") def test_fpx_weight_only(self, ebits, mbits, bias, dtype): N, OC, IC = 4, 256, 64 diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 3e894c02b9..c20920fb9f 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -24,6 +24,8 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) +from test_utils import skip_if_rocm + from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -423,6 +425,7 @@ def test_linear_from_config_params( @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_bias", [True, False]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skip_if_rocm("ROCm development in progress") def test_linear_from_recipe( self, recipe_name, diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 381886d594..4c85ee2c30 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -1,6 +1,7 @@ import unittest import torch +from test_utils import skip_if_rocm from torchao.quantization import ( MappingType, @@ -110,6 +111,7 @@ def test_hqq_plain_5bit(self): ref_dot_product_error=0.000704, ) + @skip_if_rocm("ROCm development in progress") def test_hqq_plain_4bit(self): self._test_hqq( dtype=torch.uint4, diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 1087db8cf8..935f5021f1 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -93,6 +93,8 @@ except ModuleNotFoundError: has_gemlite = False +from test_utils import skip_if_rocm + logger = logging.getLogger("INFO") torch.manual_seed(0) @@ -569,6 +571,7 @@ def test_per_token_linear_cpu(self): self._test_per_token_linear_impl("cpu", dtype) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_per_token_linear_cuda(self): for dtype in (torch.float32, torch.float16, torch.bfloat16): self._test_per_token_linear_impl("cuda", dtype) @@ -687,6 +690,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm development in progress") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -706,6 +710,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm development in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -899,6 +904,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm development in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -918,6 +924,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm development in progress") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py index bab65fc2fb..d7f8102f9f 100644 --- a/test/kernel/test_galore_downproj.py +++ b/test/kernel/test_galore_downproj.py @@ -8,6 +8,7 @@ import torch from galore_test_utils import make_data +from test_utils import skip_if_rocm from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk from torchao.prototype.galore.kernels.matmul import triton_mm_launcher @@ -29,6 +30,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS) +@skip_if_rocm("ROCm development in progress") def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype): torch.backends.cuda.matmul.allow_tf32 = allow_tf32 MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 1b91983bc0..3843d0e0cd 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -10,6 +10,8 @@ if TORCH_VERSION_AT_LEAST_2_3: from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_ +from test_utils import skip_if_rocm + class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): @@ -113,6 +115,7 @@ def test_awq_loading(device, qdtype): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@skip_if_rocm("ROCm development in progress") def test_save_weights_only(): dataset_size = 100 l1, l2, l3 = 512, 256, 128 diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index acc7576e56..8f5dccdac5 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -42,6 +42,7 @@ except ImportError: lpmm = None +from test_utils import skip_if_rocm _DEVICES = get_available_devices() @@ -112,6 +113,7 @@ class TestOptim(TestCase): ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) + @skip_if_rocm("ROCm development in progress") def test_optim_smoke(self, optim_name, dtype, device): if optim_name.endswith("Fp8") and device == "cuda": if not TORCH_VERSION_AT_LEAST_2_4: diff --git a/test/prototype/test_splitk.py b/test/prototype/test_splitk.py index 48793ba907..cd90408644 100644 --- a/test/prototype/test_splitk.py +++ b/test/prototype/test_splitk.py @@ -13,6 +13,8 @@ except ImportError: triton_available = False +from test_utils import skip_if_rocm + from torchao.utils import skip_if_compute_capability_less_than @@ -20,6 +22,7 @@ @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") class TestFP8Gemm(TestCase): @skip_if_compute_capability_less_than(9.0) + @skip_if_rocm("ROCm development in progress") def test_gemm_split_k(self): dtype = torch.float16 qdtype = torch.float8_e4m3fn diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 3eb9b0a2c5..47020d6b26 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -13,6 +13,7 @@ dequantize_blockwise, quantize_blockwise, ) +from test_utils import skip_if_rocm from torchao.prototype.galore.kernels import ( triton_dequant_blockwise, @@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): "dim1,dim2,dtype,signed,blocksize", TEST_CONFIGS, ) +@skip_if_rocm("ROCm development in progress") def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index ebdf2281e0..c21922b631 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -3,6 +3,7 @@ import pytest import torch +from test_utils import skip_if_rocm from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests @@ -45,6 +46,7 @@ def setUp(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_marlin_qqq(self): output_ref = self.model(self.input) for group_size in [-1, 128]: @@ -66,6 +68,7 @@ def test_marlin_qqq(self): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_marlin_qqq_compile(self): model_copy = copy.deepcopy(self.model) model_copy.forward = torch.compile(model_copy.forward, fullgraph=True) diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 4da7304a24..a78940656b 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -2,6 +2,7 @@ import pytest import torch +from test_utils import skip_if_rocm from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests @@ -37,6 +38,7 @@ def setUp(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_quant_sparse_marlin_layout_eager(self): apply_fake_sparsity(self.model) model_copy = copy.deepcopy(self.model) @@ -48,13 +50,13 @@ def test_quant_sparse_marlin_layout_eager(self): # Sparse + quantized quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) sparse_result = self.model(self.input) - assert torch.allclose( dense_result, sparse_result, atol=3e-1 ), "Results are not close" @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_quant_sparse_marlin_layout_compile(self): apply_fake_sparsity(self.model) model_copy = copy.deepcopy(self.model) diff --git a/test/test_ops.py b/test/test_ops.py index 26671ddf40..5a60a50e00 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -19,6 +19,9 @@ from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) + if is_fbcode(): pytest.skip( "Skipping the test in fbcode since we don't have TARGET file for kernels" diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py index 6510adaea3..93f842b2d8 100644 --- a/test/test_s8s4_linear_cutlass.py +++ b/test/test_s8s4_linear_cutlass.py @@ -7,6 +7,9 @@ from torchao.quantization.utils import group_quantize_tensor_symmetric from torchao.utils import compute_max_diff +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) + S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] S8S4_LINEAR_CUTLASS_SIZE_MNK = [ diff --git a/test/test_utils.py b/test/test_utils.py index 77a8b39aae..d4bcb7ffe0 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,11 +1,40 @@ +import functools import unittest from unittest.mock import patch +import pytest import torch from torchao.utils import TorchAOBaseTensor, torch_version_at_least +def skip_if_rocm(message=None): + """Decorator to skip tests on ROCm platform with custom message. + + Args: + message (str, optional): Additional information about why the test is skipped. + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if torch.version.hip is not None: + skip_message = "Skipping the test in ROCm" + if message: + skip_message += f": {message}" + pytest.skip(skip_message) + return func(*args, **kwargs) + + return wrapper + + # Handle both @skip_if_rocm and @skip_if_rocm() syntax + if callable(message): + func = message + message = None + return decorator(func) + return decorator + + class TestTorchVersionAtLeast(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ From 69f3795a7b60bdc6b042b6c996f8c174fcd850c6 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 17 Jan 2025 14:03:11 -0500 Subject: [PATCH 010/115] Delete unused QAT utils code (#1579) --- torchao/quantization/qat/utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torchao/quantization/qat/utils.py b/torchao/quantization/qat/utils.py index c901d59e92..80e909f48a 100644 --- a/torchao/quantization/qat/utils.py +++ b/torchao/quantization/qat/utils.py @@ -16,14 +16,6 @@ _get_per_token_block_size, ) -# Attribute name representing the forward prehook wrapping the -# linear input in an `AffineFakeQuantizedTensor` on a linear module. -# -# The value of this attribute is a 2-tuple of (prehook, handle). -# The prehook can be disabled by calling `handle.remove()`, and -# re-enabled by calling `module.register_forward_pre_hook(prehook)`. -_QAT_LINEAR_SUBCLASS_INPUT_PREHOOK = "_qat_linear_subclass_input_prehook" - class _GenericFakeQuantize(torch.autograd.Function): """ From 9afaabb405b82d94c7c7cea97b87730fa8f25bad Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 17 Jan 2025 14:11:05 -0500 Subject: [PATCH 011/115] Revert "Skip Unit Tests for ROCm CI" (#1580) Revert "Skip Unit Tests for ROCm CI (#1563)" This reverts commit a1c67b98905e81e51d56c1558742ca7e0fff49c1. --- test/__init__.py | 0 test/dtypes/test_affine_quantized.py | 4 ---- test/dtypes/test_floatx.py | 2 -- test/float8/test_base.py | 3 --- test/hqq/test_hqq_affine.py | 2 -- test/integration/test_integration.py | 7 ------- test/kernel/test_galore_downproj.py | 2 -- test/prototype/test_awq.py | 3 --- test/prototype/test_low_bit_optim.py | 2 -- test/prototype/test_splitk.py | 3 --- test/quantization/test_galore_quant.py | 2 -- test/quantization/test_marlin_qqq.py | 3 --- test/sparsity/test_marlin.py | 4 +--- test/test_ops.py | 3 --- test/test_s8s4_linear_cutlass.py | 3 --- test/test_utils.py | 29 -------------------------- 16 files changed, 1 insertion(+), 71 deletions(-) delete mode 100644 test/__init__.py diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 88e133ccf8..f08ba7aa72 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -2,7 +2,6 @@ import unittest import torch -from test_utils import skip_if_rocm from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( TestCase, @@ -90,7 +89,6 @@ def test_tensor_core_layout_transpose(self): aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) - @skip_if_rocm("ROCm development in progress") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize( "apply_quant", get_quantization_functions(True, True, "cuda", True) @@ -170,7 +168,6 @@ def apply_uint6_weight_only_quant(linear): deregister_aqt_quantized_linear_dispatch(dispatch_condition) - @skip_if_rocm("ROCm development in progress") @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): @@ -183,7 +180,6 @@ class TestAffineQuantizedBasic(TestCase): COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) COMMON_DTYPES = [torch.bfloat16] - @skip_if_rocm("ROCm development in progress") @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_flatten_unflatten(self, device, dtype): diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index ea30edfe38..8bb39b2cc8 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -2,7 +2,6 @@ import unittest import torch -from test_utils import skip_if_rocm from torch.testing._internal.common_utils import ( TestCase, instantiate_parametrized_tests, @@ -109,7 +108,6 @@ def test_to_copy_device(self, ebits, mbits): @parametrize("ebits,mbits", _Floatx_DTYPES) @parametrize("bias", [False, True]) @parametrize("dtype", [torch.half, torch.bfloat16]) - @skip_if_rocm("ROCm development in progress") @unittest.skipIf(is_fbcode(), reason="broken in fbcode") def test_fpx_weight_only(self, ebits, mbits, bias, dtype): N, OC, IC = 4, 256, 64 diff --git a/test/float8/test_base.py b/test/float8/test_base.py index c20920fb9f..3e894c02b9 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -24,8 +24,6 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) -from test_utils import skip_if_rocm - from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -425,7 +423,6 @@ def test_linear_from_config_params( @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_bias", [True, False]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @skip_if_rocm("ROCm development in progress") def test_linear_from_recipe( self, recipe_name, diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 4c85ee2c30..381886d594 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -1,7 +1,6 @@ import unittest import torch -from test_utils import skip_if_rocm from torchao.quantization import ( MappingType, @@ -111,7 +110,6 @@ def test_hqq_plain_5bit(self): ref_dot_product_error=0.000704, ) - @skip_if_rocm("ROCm development in progress") def test_hqq_plain_4bit(self): self._test_hqq( dtype=torch.uint4, diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 935f5021f1..1087db8cf8 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -93,8 +93,6 @@ except ModuleNotFoundError: has_gemlite = False -from test_utils import skip_if_rocm - logger = logging.getLogger("INFO") torch.manual_seed(0) @@ -571,7 +569,6 @@ def test_per_token_linear_cpu(self): self._test_per_token_linear_impl("cpu", dtype) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @skip_if_rocm("ROCm development in progress") def test_per_token_linear_cuda(self): for dtype in (torch.float32, torch.float16, torch.bfloat16): self._test_per_token_linear_impl("cuda", dtype) @@ -690,7 +687,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") - @skip_if_rocm("ROCm development in progress") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -710,7 +706,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") - @skip_if_rocm("ROCm development in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -904,7 +899,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") - @skip_if_rocm("ROCm development in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -924,7 +918,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") - @skip_if_rocm("ROCm development in progress") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py index d7f8102f9f..bab65fc2fb 100644 --- a/test/kernel/test_galore_downproj.py +++ b/test/kernel/test_galore_downproj.py @@ -8,7 +8,6 @@ import torch from galore_test_utils import make_data -from test_utils import skip_if_rocm from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk from torchao.prototype.galore.kernels.matmul import triton_mm_launcher @@ -30,7 +29,6 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS) -@skip_if_rocm("ROCm development in progress") def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype): torch.backends.cuda.matmul.allow_tf32 = allow_tf32 MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 3843d0e0cd..1b91983bc0 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -10,8 +10,6 @@ if TORCH_VERSION_AT_LEAST_2_3: from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_ -from test_utils import skip_if_rocm - class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): @@ -115,7 +113,6 @@ def test_awq_loading(device, qdtype): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@skip_if_rocm("ROCm development in progress") def test_save_weights_only(): dataset_size = 100 l1, l2, l3 = 512, 256, 128 diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 8f5dccdac5..acc7576e56 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -42,7 +42,6 @@ except ImportError: lpmm = None -from test_utils import skip_if_rocm _DEVICES = get_available_devices() @@ -113,7 +112,6 @@ class TestOptim(TestCase): ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) - @skip_if_rocm("ROCm development in progress") def test_optim_smoke(self, optim_name, dtype, device): if optim_name.endswith("Fp8") and device == "cuda": if not TORCH_VERSION_AT_LEAST_2_4: diff --git a/test/prototype/test_splitk.py b/test/prototype/test_splitk.py index cd90408644..48793ba907 100644 --- a/test/prototype/test_splitk.py +++ b/test/prototype/test_splitk.py @@ -13,8 +13,6 @@ except ImportError: triton_available = False -from test_utils import skip_if_rocm - from torchao.utils import skip_if_compute_capability_less_than @@ -22,7 +20,6 @@ @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") class TestFP8Gemm(TestCase): @skip_if_compute_capability_less_than(9.0) - @skip_if_rocm("ROCm development in progress") def test_gemm_split_k(self): dtype = torch.float16 qdtype = torch.float8_e4m3fn diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 47020d6b26..3eb9b0a2c5 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -13,7 +13,6 @@ dequantize_blockwise, quantize_blockwise, ) -from test_utils import skip_if_rocm from torchao.prototype.galore.kernels import ( triton_dequant_blockwise, @@ -83,7 +82,6 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): "dim1,dim2,dtype,signed,blocksize", TEST_CONFIGS, ) -@skip_if_rocm("ROCm development in progress") def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index c21922b631..ebdf2281e0 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -3,7 +3,6 @@ import pytest import torch -from test_utils import skip_if_rocm from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests @@ -46,7 +45,6 @@ def setUp(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") - @skip_if_rocm("ROCm development in progress") def test_marlin_qqq(self): output_ref = self.model(self.input) for group_size in [-1, 128]: @@ -68,7 +66,6 @@ def test_marlin_qqq(self): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") - @skip_if_rocm("ROCm development in progress") def test_marlin_qqq_compile(self): model_copy = copy.deepcopy(self.model) model_copy.forward = torch.compile(model_copy.forward, fullgraph=True) diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index a78940656b..4da7304a24 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -2,7 +2,6 @@ import pytest import torch -from test_utils import skip_if_rocm from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests @@ -38,7 +37,6 @@ def setUp(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") - @skip_if_rocm("ROCm development in progress") def test_quant_sparse_marlin_layout_eager(self): apply_fake_sparsity(self.model) model_copy = copy.deepcopy(self.model) @@ -50,13 +48,13 @@ def test_quant_sparse_marlin_layout_eager(self): # Sparse + quantized quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) sparse_result = self.model(self.input) + assert torch.allclose( dense_result, sparse_result, atol=3e-1 ), "Results are not close" @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") - @skip_if_rocm("ROCm development in progress") def test_quant_sparse_marlin_layout_compile(self): apply_fake_sparsity(self.model) model_copy = copy.deepcopy(self.model) diff --git a/test/test_ops.py b/test/test_ops.py index 5a60a50e00..26671ddf40 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -19,9 +19,6 @@ from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode -if torch.version.hip is not None: - pytest.skip("Skipping the test in ROCm", allow_module_level=True) - if is_fbcode(): pytest.skip( "Skipping the test in fbcode since we don't have TARGET file for kernels" diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py index 93f842b2d8..6510adaea3 100644 --- a/test/test_s8s4_linear_cutlass.py +++ b/test/test_s8s4_linear_cutlass.py @@ -7,9 +7,6 @@ from torchao.quantization.utils import group_quantize_tensor_symmetric from torchao.utils import compute_max_diff -if torch.version.hip is not None: - pytest.skip("Skipping the test in ROCm", allow_module_level=True) - S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] S8S4_LINEAR_CUTLASS_SIZE_MNK = [ diff --git a/test/test_utils.py b/test/test_utils.py index d4bcb7ffe0..77a8b39aae 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,40 +1,11 @@ -import functools import unittest from unittest.mock import patch -import pytest import torch from torchao.utils import TorchAOBaseTensor, torch_version_at_least -def skip_if_rocm(message=None): - """Decorator to skip tests on ROCm platform with custom message. - - Args: - message (str, optional): Additional information about why the test is skipped. - """ - - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if torch.version.hip is not None: - skip_message = "Skipping the test in ROCm" - if message: - skip_message += f": {message}" - pytest.skip(skip_message) - return func(*args, **kwargs) - - return wrapper - - # Handle both @skip_if_rocm and @skip_if_rocm() syntax - if callable(message): - func = message - message = None - return decorator(func) - return decorator - - class TestTorchVersionAtLeast(unittest.TestCase): def test_torch_version_at_least(self): test_cases = [ From 1240b19fd719d54af64c2d4d8b5cc33aba345dce Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 17 Jan 2025 14:11:43 -0500 Subject: [PATCH 012/115] Revert "Enable ROCM in CI" (#1583) Revert "Enable ROCM in CI (#999)" This reverts commit d96c6a79adcf1f4fa127b0cd7f762921bb951c8a. --- .github/workflows/regression_test.yml | 13 +++---------- torchao/utils.py | 2 +- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index eaf2e3cbbb..74b39d2ef2 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -17,10 +17,6 @@ concurrency: env: HF_TOKEN: ${{ secrets.HF_TOKEN }} -permissions: - id-token: write - contents: read - jobs: test-nightly: strategy: @@ -37,16 +33,10 @@ jobs: torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" - - name: ROCM Nightly - runs-on: linux.rocm.gpu.2 - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3' - gpu-arch-type: "rocm" - gpu-arch-version: "6.3" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 120 - no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }} runner: ${{ matrix.runs-on }} gpu-arch-type: ${{ matrix.gpu-arch-type }} gpu-arch-version: ${{ matrix.gpu-arch-version }} @@ -81,6 +71,7 @@ jobs: torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121' gpu-arch-type: "cuda" gpu-arch-version: "12.1" + - name: CPU 2.3 runs-on: linux.4xlarge torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu' @@ -108,6 +99,8 @@ jobs: conda create -n venv python=3.9 -y conda activate venv echo "::group::Install newer objcopy that supports --set-section-alignment" + yum install -y devtoolset-10-binutils + export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH python -m pip install --upgrade pip pip install ${{ matrix.torch-spec }} pip install -r dev-requirements.txt diff --git a/torchao/utils.py b/torchao/utils.py index 4729675a14..7a17c1b104 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -607,7 +607,7 @@ def _torch_version_at_least(min_version): def is_MI300(): if torch.cuda.is_available() and torch.version.hip: mxArchName = ["gfx940", "gfx941", "gfx942"] - archName = torch.cuda.get_device_properties(0).gcnArchName + archName = torch.cuda.get_device_properties().gcnArchName for arch in mxArchName: if arch in archName: return True From 32d9b0bc05e4cce0bd18438b02cb819891d36a49 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 17 Jan 2025 16:06:05 -0800 Subject: [PATCH 013/115] Fix CI linux_job permissions (#1576) --- .github/workflows/float8_test.yml | 3 +++ .github/workflows/nightly_smoke_test.yml | 6 ++++-- .github/workflows/regression_test.yml | 3 +++ test/integration/test_integration.py | 2 +- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.github/workflows/float8_test.yml b/.github/workflows/float8_test.yml index 75482c9e24..7c9e5a4b00 100644 --- a/.github/workflows/float8_test.yml +++ b/.github/workflows/float8_test.yml @@ -29,6 +29,9 @@ jobs: gpu-arch-type: "cuda" gpu-arch-version: "12.1" + permissions: + id-token: write + contents: read uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 60 diff --git a/.github/workflows/nightly_smoke_test.yml b/.github/workflows/nightly_smoke_test.yml index d215f22ed2..18d4f41af6 100644 --- a/.github/workflows/nightly_smoke_test.yml +++ b/.github/workflows/nightly_smoke_test.yml @@ -11,7 +11,7 @@ concurrency: cancel-in-progress: true env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} + HF_TOKEN: ${{ secrets.HF_TOKEN }} jobs: test: @@ -25,7 +25,9 @@ jobs: gpu-arch-type: "cuda" gpu-arch-version: "12.1" - + permissions: + id-token: write + contents: read uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: ${{ matrix.runs-on }} diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 74b39d2ef2..19c033c4d1 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -34,6 +34,9 @@ jobs: gpu-arch-type: "cpu" gpu-arch-version: "" + permissions: + id-token: write + contents: read uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: timeout: 120 diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 1087db8cf8..c926cee060 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1821,7 +1821,7 @@ def test_autoquant_int4wo(self, device, dtype): self.assertGreater(compute_error(ref, out), 20) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90") @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+." ) From ea7910e5c24523ea901aabe7945ce7ac0ffa1033 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= <115986737+alexsamardzic@users.noreply.github.com> Date: Tue, 21 Jan 2025 21:30:15 +0100 Subject: [PATCH 014/115] Refactor s8s4_linear_cutlass() (#1545) Refactor CUTLASS-based code so it could support operators other than W4A8 --- .../s8s4_linear_cutlass.cu | 489 ++++++++++-------- 1 file changed, 267 insertions(+), 222 deletions(-) diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu index 2daefb7773..411343f0da 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -29,26 +29,35 @@ namespace torchao { #if defined(BUILD_S8S4_LINEAR_CUTLASS) template< - typename ElementA, - typename ElementAScale, - typename ElementB, - typename ElementBScale, - typename ElementC, - typename ElementAccumulator, - typename ElementEpilogue, - typename ElementOutput, typename ThreadblockShape, typename WarpShape, typename InstructionShape, int NumStages, - bool use_tensor_c> -void s8s4_linear_kernel_cutlass( + typename ElementA, + typename ElementB, + typename ElementAccumulator, + typename Operator, + typename ElementAScale, + typename ElementBScale, + typename ElementC, + typename UseTensorC, + typename ElementOutput> +void s8s4_linear_kernel_cutlass_sm8x( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { + using SmArch = cutlass::arch::Sm80; + using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementEpilogue = float; + + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + + constexpr auto NumEVTEpilogueStages = 1; const int m = tensor_a.size(0); const int n = tensor_b.size(0); @@ -56,13 +65,13 @@ void s8s4_linear_kernel_cutlass( constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentAScale = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentBScale = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentOutput = - 128 / cutlass::sizeof_bits::value; + 128 / cutlass::sizeof_bits::value; // Check for current CUTLASS limitations w.r.t. alignments. TORCH_CHECK(k % AlignmentA == 0, @@ -75,12 +84,6 @@ void s8s4_linear_kernel_cutlass( __func__, " : Number of columns of tensor C must be divisible ", "by ", AlignmentC); - using SmArch = cutlass::arch::Sm80; - using ThreadblockSwizzle = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; - - constexpr auto NumEVTEpilogueStages = 1; - using TensorAScaleTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< ThreadblockShape, @@ -132,9 +135,9 @@ void s8s4_linear_kernel_cutlass( cutlass::epilogue::threadblock::VisitorRowBroadcast< TensorCTileThreadMap, ElementC, - cute::Stride>; + cute::Stride>; using TensorC = - std::conditional_t; + std::conditional_t; using TensorCArguments = typename TensorC::Arguments; using ApplyAScale = cutlass::epilogue::threadblock::VisitorCompute< @@ -178,7 +181,7 @@ void s8s4_linear_kernel_cutlass( typename cutlass::gemm::kernel::DefaultGemmWithVisitor< ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, - ElementC, LayoutC, AlignmentC, + ElementOutput, LayoutOutput, AlignmentOutput, ElementAccumulator, ElementEpilogue, cutlass::arch::OpClassTensorOp, @@ -189,7 +192,7 @@ void s8s4_linear_kernel_cutlass( EVTOutput, ThreadblockSwizzle, NumStages, - cutlass::arch::OpMultiplyAddMixedInputUpcast, + Operator, NumEVTEpilogueStages >::GemmKernel; @@ -210,7 +213,7 @@ void s8s4_linear_kernel_cutlass( }; TensorCArguments tensor_c_arguments{ [&]() -> TensorCArguments { - if constexpr (use_tensor_c) { + if constexpr (UseTensorC::value) { return {(ElementC*)tensor_c.data_ptr(), ElementC(0), {cute::_0{}, cute::_1{}, problem_size.n()}}; @@ -282,127 +285,193 @@ void s8s4_linear_kernel_cutlass( // Perform mixed datatypes GEMM operation. status = gemm_op.run(at::cuda::getCurrentCUDAStream()); CUTLASS_STATUS_CHECK(status); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); } -template< - typename ElementA, - typename ElementAScale, - typename ElementB, - typename ElementBScale, - typename ElementC, - typename ElementAccumulator, - typename ElementEpilogue, - typename ElementOutput, - bool use_tensor_c> -void -s8s4_linear_cutlass_dispatch_shapes( +template +static void select_config( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm8x = dprops->major == 8; + + if (is_sm8x) { + if constexpr (std::is_same::value && + std::is_same::value) { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + // A minimal heuristic to improve performance for small number + // of inputs cases. + if (tensor_a.size(0) <= 16) { + using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; + constexpr auto NumStages = 6; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else if (tensor_a.size(0) <= 32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + constexpr auto NumStages = 5; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + constexpr auto NumStages = 4; + s8s4_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, + ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + return; + } + } + + TORCH_CHECK(false, + __func__, " : Operator not supported on SM", dprops->major, ".", + dprops->minor, " for given operands"); +} + +template +static void +dispatch_on_tensor_a_and_tensor_b( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - - // A minimal heuristic to improve performance for small number of - // inputs cases. - if (tensor_a.size(0) <= 16) { - using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; - constexpr auto NumStages = 6; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else if (tensor_a.size(0) <= 32) { - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; - constexpr auto NumStages = 5; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; - using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; - constexpr auto NumStages = 4; - s8s4_linear_kernel_cutlass< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, - ThreadblockShape, WarpShape, InstructionShape, NumStages, use_tensor_c>( + if (tensor_a.scalar_type() == at::ScalarType::Char) { + if (tensor_b.scalar_type() == at::ScalarType::Char) { + if (tensor_a.size(1) == 2 * tensor_b.size(1)) { + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + using ElementAccumulator = int32_t; + using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; + select_config< + ElementA, ElementB, ElementAccumulator, Operator, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + } + return; + } } + + TORCH_CHECK(false, + __func__, " : Operator not supported for combination of data ", + "types ", tensor_a.scalar_type(), " for first operand and ", + tensor_b.scalar_type(), " for second operand"); } -#endif -// Perform linear operation, using corresponding CUTLASS mixed -// data-types GEMM kernel, to given arguments: -// result = (input * input_scale) @ (weight * weight_scale).T + bias -// Notes: The "input_scale" tensor is expected to be a vector, of size -// equal to number of rows of "input" tensor. The "weight_scale" -// tensor is expected to be a vector, of size equal to number of rows -// of "weight" tensor. The "bias" tensor is expected to be a vector, -// of size equal to number of rows of "weight" tensor. -at::Tensor -s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, - const at::Tensor& weight, const at::Tensor& weight_scale, - const at::Tensor& bias) { -#if defined(BUILD_S8S4_LINEAR_CUTLASS) - // For now, only CC 8.x devices are supported. - const auto dprops = at::cuda::getCurrentDeviceProperties(); - const auto is_sm8x = dprops->major == 8; - TORCH_CHECK(is_sm8x, - __func__, " : Supported only on GPUs with compute capability " - "8.x"); - - // Validate datatypes of arguments. - TORCH_CHECK(input.dtype() == at::kChar, - __func__, " : The input datatype ", input.dtype(), - " not supported"); - TORCH_CHECK(input_scale.dtype() == at::kHalf || - input_scale.dtype() == at::kBFloat16, - __func__, " : The input scale datatype ", input_scale.dtype(), - " not supported"); - TORCH_CHECK(weight.dtype() == at::kChar, " : The weight datatype ", - weight.dtype(), " not supported"); - TORCH_CHECK(weight_scale.dtype() == input_scale.dtype(), - __func__, " : Expected weight scale datatype ", - input_scale.dtype(), ", got ", weight_scale.dtype()); - if (bias.numel() > 0) { - TORCH_CHECK(bias.dtype() == input_scale.dtype(), - __func__, " : Expected bias datatype ", input_scale.dtype(), - ", got ", bias.dtype()); +template +static void +dispatch_on_tensor_c( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + if (tensor_c.numel() == 0) { + using ElementC = ElementOutput; + using UseTensorC = std::false_type; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } + + using UseTensorC = std::true_type; + if (tensor_c.scalar_type() == at::ScalarType::Half) { + using ElementC = cutlass::half_t; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; + } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { + using ElementC = cutlass::bfloat16_t; + dispatch_on_tensor_a_and_tensor_b< + ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + return; } + TORCH_CHECK(false, + __func__, " : Operator not supported for datatype ", + tensor_c.scalar_type(), " for addend"); +} + +static void +dispatch_on_tensor_a_scale_and_tensor_b_scale( + const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, + const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, + const at::Tensor& tensor_c, at::Tensor& tensor_d) { + TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), + __func__, " : Operator not supported for output datatype ", + tensor_d.scalar_type(), " as it's different from the first ", + " operand scale datatype ", tensor_a_scale.scalar_type()); + + if (tensor_a_scale.scalar_type() == at::ScalarType::Half && + tensor_b_scale.scalar_type() == at::ScalarType::Half) { + using ElementAScale = cutlass::half_t; + using ElementBScale = cutlass::half_t; + using ElementOutput = cutlass::half_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && + tensor_b_scale.scalar_type() == at::ScalarType::BFloat16) { + using ElementAScale = cutlass::bfloat16_t; + using ElementBScale = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + dispatch_on_tensor_c( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + return; + } + + TORCH_CHECK(false, + __func__, " : Operator not supported for combination of data ", + "types ", tensor_a_scale.scalar_type(), + " for first operand scale and ", tensor_b_scale.scalar_type(), + " for second operand scale"); +} + +void +check_inputs( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { // Validate layouts of arguments. - TORCH_CHECK(input.dim() >= 2, - __func__, " : Expected input argument to be 2D or " - "higher-dimensional tensor, got ", input.dim(), " dims"); - TORCH_CHECK(input.layout() == at::Layout::Strided, - __func__, " : Expected input argument to be strided, got layout ", - input.layout()); - TORCH_CHECK(input_scale.dim() == input.dim() - 1, - __func__, " : Expected input scale argument to be ", - input.dim() - 1, "D tensor, got ", input_scale.dim(), " dims"); - TORCH_CHECK(input_scale.layout() == at::Layout::Strided, - __func__, " : Expected input scale argument to be strided, got " - "layout ", input_scale.layout()); - TORCH_CHECK(weight.dim() == 2, - __func__, " : Expected weight argument to be 2D tensor, got ", - weight.dim(), " dims"); - TORCH_CHECK(weight.layout() == at::Layout::Strided, - __func__, - " : Expected weight argument to be strided, got layout ", - weight.layout()); - TORCH_CHECK(weight_scale.dim() == 1 || weight_scale.dim() == 2, - __func__, " : Expected weight scale argument to be 1D or 2D ", - "tensor, got ", weight_scale.dim(), " dims"); - TORCH_CHECK(weight_scale.layout() == at::Layout::Strided, - __func__, " : Expected weight scale argument to be strided, got " - "layout ", weight_scale.layout()); + TORCH_CHECK(xq.dim() >= 2, + __func__, " : Expected xq argument to be 2D or " + "higher-dimensional tensor, got ", xq.dim(), " dims"); + TORCH_CHECK(xq.layout() == at::Layout::Strided, + __func__, " : Expected xq argument to be strided, got layout ", + xq.layout()); + TORCH_CHECK(x_scale.dim() == xq.dim() - 1, + __func__, " : Expected xq scale argument to be ", xq.dim() - 1, + "D tensor, got ", x_scale.dim(), " dims"); + TORCH_CHECK(x_scale.layout() == at::Layout::Strided, + __func__, " : Expected xq scale argument to be strided, got " + "layout ", x_scale.layout()); + TORCH_CHECK(wq.dim() == 2, + __func__, " : Expected wq argument to be 2D tensor, got ", + wq.dim(), " dims"); + TORCH_CHECK(wq.layout() == at::Layout::Strided, + __func__, " : Expected wq argument to be strided, got layout ", + wq.layout()); + TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, + __func__, " : Expected wq scale argument to be 1D or 2D tensor, ", + "got ", w_scale.dim(), " dims"); + TORCH_CHECK(w_scale.layout() == at::Layout::Strided, + __func__, " : Expected wq scale argument to be strided, got " + "layout ", w_scale.layout()); if (bias.numel() > 0) { TORCH_CHECK(bias.dim() == 1, __func__, " : Expected bias argument to be 1D tensor, got ", @@ -412,116 +481,92 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale, "layout ", bias.layout()); } - // Squash the input tensor to 2D tensor. - const auto input_sizes = input.sizes().vec(); - const auto input_2d = input.reshape({-1, input_sizes.back()}); - const auto input_scale_sizes = input_scale.sizes().vec(); - const auto input_scale_1d = input_scale.reshape({-1}); - const auto weight_scale_1d = weight_scale.reshape({-1}); - // Validate sizes of arguments. - TORCH_CHECK(input_2d.size(1) == 2 * weight.size(1), - __func__, " : Expected input argument to have ", - 2 * weight.size(1), " columns, but got ", input_2d.size(1)); - for (auto i = 0; i < input_scale_sizes.size(); ++i) - TORCH_CHECK(input_scale_sizes[i] == input_sizes[i], - __func__, " : Expected input scale argument size at position ", - i, " to be ", input_sizes[i], ", but got ", - input_scale_sizes[i]); - TORCH_CHECK(weight_scale_1d.numel() == weight.size(0), - __func__, " : Expected weight scale argument to have ", - weight.size(0), " elements, got ", weight_scale_1d.numel(), - " elements"); + const auto xq_sizes = xq.sizes().vec(); + TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1), + __func__, " : Expected xq argument to have ", 2 * wq.size(1), + " columns, but got ", xq_sizes.back()); + const auto x_scale_sizes = x_scale.sizes().vec(); + for (auto i = 0; i < x_scale_sizes.size(); ++i) + TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], + __func__, " : Expected xq scale argument size at position ", + i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]); + TORCH_CHECK(w_scale.numel() == wq.size(0), + __func__, " : Expected wq scale argument to have ", wq.size(0), + " elements, got ", w_scale.numel(), " elements"); if (bias.numel() > 0) { - TORCH_CHECK(bias.numel() == weight.size(0), - __func__, " : Expected bias argument to have ", weight.size(0), + TORCH_CHECK(bias.numel() == wq.size(0), + __func__, " : Expected bias argument to have ", wq.size(0), " elements, got ", bias.numel(), " elements"); } // Validate strides of arguments. - const auto input_2d_strides = input_2d.strides(); - TORCH_CHECK(input_2d_strides[0] >= 1 && input_2d_strides[1] == 1, - __func__, " : Expected input argument in row-major layout"); - const auto input_scale_1d_strides = input_scale_1d.strides(); - TORCH_CHECK(input_scale_1d_strides[0] == 1, - __func__, " : Expected input scale argument to be contiguous"); - const auto weight_strides = weight.strides(); - TORCH_CHECK(weight_strides[0] >= 1 && weight_strides[1] == 1, - __func__, " : Expected weight argument in row-major layout"); - const auto weight_scale_1d_strides = weight_scale_1d.strides(); - TORCH_CHECK(weight_scale_1d_strides[0] == 1, - __func__, " : Expected weight scale argument to be contiguous"); + const auto xq_strides = xq.strides(); + TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, + __func__, " : Expected xq argument in row-major layout"); + auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; + for (int i = xq_strides.size() - 3; i >= 0; --i) { + xq_stride_expected *= xq_sizes[i + 1]; + TORCH_CHECK(xq_strides[i] == xq_stride_expected, + __func__, " : Expected xq argument in row-major layout"); + } + TORCH_CHECK(x_scale.is_contiguous(), + __func__, " : Expected xq scale argument to be contiguous"); + const auto wq_strides = wq.strides(); + TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, + __func__, " : Expected wq argument in row-major layout"); + TORCH_CHECK(w_scale.is_contiguous(), + __func__, " : Expected wq scale argument to be contiguous"); if (bias.numel() > 0) { const auto bias_strides = bias.strides(); TORCH_CHECK(bias_strides[0] == 1, __func__, " : Expected bias argument to be contiguous"); } +} +#endif + +// Perform linear operation, using corresponding CUTLASS mixed +// data-types GEMM kernel, to given arguments: +// result = (xq * x_scale) @ (wq * w_scale).T + bias +// Notes: The "x_scale" tensor is expected to be a vector, of size +// equal to number of rows of "xq" tensor. The "w_scale" tensor is +// expected to be a vector, of size equal to number of rows of "wq" +// tensor. The "bias" tensor is expected to be a vector, of size equal +// to number of rows of "wq" tensor. +at::Tensor +s8s4_linear_cutlass( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { +#if defined(BUILD_S8S4_LINEAR_CUTLASS) + // Check inputs. + check_inputs(xq, x_scale, wq, w_scale, bias); + + // Squash the input tensors as appropriate. + const auto xq_sizes = xq.sizes().vec(); + const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); + const auto x_scale_sizes = x_scale.sizes().vec(); + const auto x_scale_1d = x_scale.reshape({-1}); + const auto w_scale_1d = w_scale.reshape({-1}); // Introduce alias names for arguments, according to the CUTLASS // naming conventions. - const auto& tensor_a = input_2d; - const auto& tensor_a_scale = input_scale_1d; - const auto& tensor_b = weight; - const auto& tensor_b_scale = weight_scale_1d; + const auto& tensor_a = xq_2d; + const auto& tensor_a_scale = x_scale_1d; + const auto& tensor_b = wq; + const auto& tensor_b_scale = w_scale_1d; const auto& tensor_c = bias; // Create output tensor. at::Tensor tensor_d = tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); - using ElementA = int8_t; - using ElementB = cutlass::int4b_t; - using ElementAccumulator = int32_t; - AT_DISPATCH_SWITCH( - input_scale.scalar_type(), - "s8s4_linear_cutlass", - AT_DISPATCH_CASE( - at::ScalarType::Half, - [&]() { - using ElementAScale = cutlass::half_t; - using ElementBScale = cutlass::half_t; - using ElementC = cutlass::half_t; - using ElementEpilogue = float; - using ElementOutput = cutlass::half_t; - if (bias.numel() > 0) { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, true>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, false>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - }) - AT_DISPATCH_CASE( - at::ScalarType::BFloat16, - [&]() { - using ElementAScale = cutlass::bfloat16_t; - using ElementBScale = cutlass::bfloat16_t; - using ElementC = cutlass::bfloat16_t; - using ElementEpilogue = float; - using ElementOutput = cutlass::bfloat16_t; - if (bias.numel() > 0) { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, true>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } else { - s8s4_linear_cutlass_dispatch_shapes< - ElementA, ElementAScale, ElementB, ElementBScale, ElementC, - ElementAccumulator, ElementEpilogue, ElementOutput, false>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - })); - - auto tensor_d_sizes = input_sizes; - tensor_d_sizes.back() = weight.size(0); + // Dispatch to appropriate kernel template. + dispatch_on_tensor_a_scale_and_tensor_b_scale( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + + // Reshape and return output tensor. + auto tensor_d_sizes = xq_sizes; + tensor_d_sizes.back() = wq.size(0); return tensor_d.reshape(tensor_d_sizes); #else TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); From 5d1444bdef6df15eb89c4c5716ede1c5f8677798 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 21 Jan 2025 15:22:03 -0800 Subject: [PATCH 015/115] Sparsity docs update (#1590) --- docs/source/api_ref_sparsity.rst | 6 +++--- torchao/sparsity/sparse_api.py | 32 ++++++++++++++++---------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst index 8023d0bacc..33c652390d 100644 --- a/docs/source/api_ref_sparsity.rst +++ b/docs/source/api_ref_sparsity.rst @@ -12,7 +12,7 @@ torchao.sparsity WandaSparsifier PerChannelNormObserver - apply_sparse_semi_structured apply_fake_sparsity - - + sparsify_ + semi_sparse_weight + int8_dynamic_activation_int8_semi_sparse_weight diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index 3dd7971525..eb31cba619 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -43,7 +43,7 @@ def sparsify_( apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, ) -> torch.nn.Module: - """Convert the weight of linear modules in the model with `apply_tensor_subclass` + """Convert the weight of linear modules in the model with `apply_tensor_subclass`. This function is essentially the same as quantize, put for sparsity subclasses. Currently, we support three options for sparsity: @@ -54,26 +54,26 @@ def sparsify_( Args: model (torch.nn.Module): input model apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (sparsified) tensor subclass instance (e.g. affine quantized tensor instance) - filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on - the weight of the module + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module - Example:: - import torch - import torch.nn as nn - from torchao.sparsity import sparsify_ + **Example:** + :: + import torch + import torch.nn as nn + from torchao.sparsity import sparsify_ - def filter_fn(module: nn.Module, fqn: str) -> bool: - return isinstance(module, nn.Linear) + def filter_fn(module: nn.Module, fqn: str) -> bool: + return isinstance(module, nn.Linear) - m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) + m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - # for 2:4 sparsity - from torchao.sparse_api import semi_sparse_weight - m = sparsify_(m, semi_sparse_weight(), filter_fn) + # for 2:4 sparsity + from torchao.sparse_api import semi_sparse_weight + m = sparsify_(m, semi_sparse_weight(), filter_fn) - # for int8 dynamic quantization + 2:4 sparsity - from torchao.dtypes import SemiSparseLayout - m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) + # for int8 dynamic quantization + 2:4 sparsity + from torchao.dtypes import SemiSparseLayout + m = quantize_(m, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout), filter_fn) """ _replace_with_custom_fn_if_matches_filter( model, From 166a35768a60964a2415be9823d800b24ed00cf3 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 22 Jan 2025 15:46:03 -0800 Subject: [PATCH 016/115] Sparsity getting started docs (#1592) --- docs/source/index.rst | 95 +---- docs/source/sparsity.rst | 731 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 744 insertions(+), 82 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index c008c80453..3bbcd203fd 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -3,80 +3,25 @@ Welcome to the torchao Documentation `torchao `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on: -1. API Reference -2. Developer Contribution Guide -3. Tutorials +1. Getting Started +2. Developer Notes +3. API Reference +4. Tutorials -.. - .. grid:: 3 - - .. grid-item-card:: :octicon:`file-code;1em` - Getting Started - :img-top: _static/img/card-background.svg - :link: getting-started.html - :link-type: url - - Learn about how to get started with torchao - and ts application in your projects. - - .. grid-item-card:: :octicon:`file-code;1em` - Concepts - :img-top: _static/img/card-background.svg - :link: dtypes.html - :link-type: url - - Learn about the key torchao concepts such - as dtypes, quantization, sparsity, among others. - - .. grid-item-card:: :octicon:`file-code;1em` - API Reference - :img-top: _static/img/card-background.svg - :link: api_ref_intro.html - :link-type: url - - A comprehensive reference for the torchao - API and its functionalities. - - Tutorials - ~~~~~~~~~ - - Ready to experiment? Check out some of the - torchao tutorials. - - .. customcardstart:: - - .. customcarditem:: - :header: Template Tutorial - :card_description: A placeholder template for demo purposes - :image: _static/img/generic-pytorch-logo.png - :link: tutorials/template_tutorial.html - :tags: template - - .. customcardend:: - - -.. ---------------------------------------------------------------------- -.. Below is the toctree i.e. it defines the content of the left sidebar. -.. Each of the entry below corresponds to a file.rst in docs/source/. -.. ---------------------------------------------------------------------- - -.. - .. toctree:: - :glob: - :maxdepth: 1 - :caption: Getting Started - :hidden: +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Getting Started - overview - getting-started + getting-started + sparsity - .. toctree:: - :glob: - :maxdepth: 1 - :caption: Tutorials - :hidden: +.. toctree:: + :glob: + :maxdepth: 1 + :caption: Developer Notes - tutorials/template_tutorial + contributor_guide .. toctree:: :glob: @@ -86,15 +31,6 @@ Welcome to the torchao Documentation api_ref_dtypes api_ref_quantization api_ref_sparsity -.. - api_ref_kernel - -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Contributor Guide - - contributor_guide .. toctree:: :glob: @@ -102,4 +38,3 @@ Welcome to the torchao Documentation :caption: Tutorials serialization - diff --git a/docs/source/sparsity.rst b/docs/source/sparsity.rst index 273ee5b770..0bde173b6d 100644 --- a/docs/source/sparsity.rst +++ b/docs/source/sparsity.rst @@ -1,4 +1,731 @@ Sparsity -======== +-------- -TBA +Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve significant reduction in memory overhead and latency, while paying a reasonably low or no price in terms of model quality (accuracy / f1). + +Goal +==== + +We feel that the main problem current sparsity researchers / users face is fragmentation. Researchers rightfully aim to show end-to-end results, but this means a lot of time is spent figuring out how to integrate with PyTorch and implementation questions like: + + +* *When should I mask?* +* *When/how should I store the compressed representation?* +* *Do I want in-place or out-of-place mask updates?* +* *How can I call sparse matmul instead of dense?* + +We feel like the above problems can be solved once by ``torchao``\ , letting researchers focus on what really matters - pushing sparse kernel performance or more accurate pruning algorithms. + +More concretely, we hope to provide tutorials and APIs for both sparse kernels (tensor subclassing) and pruning algorithms (torch.ao.pruning.Sparsifier) that users can extend. We aim to provide modular building blocks, that can be used to accelerate not only inference but training as well, and that compose nicely with ``torchao`` quantization workflows. + + +#. Train sparse models from scratch with hardware acceleration, with minimal accuracy loss. +#. Recover accuracy loss of pruned model with custom pruning algorthim. +#. Accelerate masked/pruned models on sparsity-supported hardware to realize performance improvements. + +Design +====== + +Sparsity, like quantization, is an accuracy/performance trade-off, where we care not only about the speedup but also on the accuracy degradation of our architecture optimization technique. + +In quantization, the theoretical performance gain is generally determined by the data type that we are quantizing to - quantizing from float32 to float16 yields a theoretical 2x speedup. For pruning/sparsity, the analogous variable would be the sparsity level/ sparsity pattern. For semi-structured, the sparsity level is fixed at 50%, so we expect a theoretical 2x improvement. For block-sparse matrices and unstructured sparsity, the speedup is variable and depends on the sparsity level of the tensor. + +One key difference between sparsity and quantization is in how the accuracy degradation is determined: In general, the accuracy degradation of quantization is determined by the scale and zero_point chosen. However, in pruning the accuracy degradation is determined by the mask. Sparsity and quantization are closely related and share accuracy mitigation techniques like quantization/sparsity aware training. + +By carefully choosing the specified elements and retraining the network, pruning can achieve negligible accuracy degradation and in some cases even provide a slight accuracy gain. This is an active area of research with no agreed-upon consensus. We expect users will have a target sparsity pattern and mind and to prune to that pattern. + +Given a target sparsity pattern, pruning/sparsifying a model can then be thought of as two separate subproblems: + + +* **Accuracy** - How can I find a set of sparse weights which satisfy my target sparsity pattern that minimize the accuracy degradation of my model? +* **Perforance** - How can I accelerate my sparse weights for inference and reduce memory overhead? + +Our workflow is designed to consist of two parts that answer each question independently: + + +* a frontend python user-facing API to find sparse weights for any arbitrary sparsity pattern. +* a backend collection of sparse kernels / ops to reduce memory/latency. + +The handoff point between these two pieces are sparse weights stored in a dense format, with 0 in the place of missing elements. This is a natural handoff point because sparse matrix multiplication and dense matrix multiplication with this tensor will be numerically equivalent. This lets us present a clear contract to the user for our backend, for a given sparsity pattern: + +If you can get your dense matrix into a **2:4 sparse format**, we can speed up matrix multiplication up to **1.7x** with no numerical loss. + +This also allows users with existing sparse weights in a dense format to take advantage of our fast sparse kernels. We anticipate many users to come up with their own custom frontend masking solution or to use another third party solution, as this is an active area of research. + + +.. image:: ../static/pruning_ecosystem_diagram.png + :alt: pruning_flow + + +Below, we provide an example of accelerating a model with 2:4 sparsity + bf16 using our PyTorch APIs. + +.. code-block:: python + + import torch + from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor + from torch.ao.pruning import WeightNormSparsifier + + # bfloat16 CUDA model + model = model.half().cuda() + + # Accuracy: Finding a sparse subnetwork + sparse_config = [] + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear): + sparse_config.append({"tensor_fqn": f"{name}.weight"}) + + sparsifier = WeightNormSparsifier(sparsity_level=1.0, + sparse_block_shape=(1,4), + zeros_per_block=2) + + # attach FakeSparsity + sparsifier.prepare(model, sparse_config) + sparsifier.step() + sparsifier.squash_mask() + # now we have dense model with sparse weights + + # Performance: Accelerated sparse inference + for name, mod in model.named_modules(): + if isinstance(mod, torch.nn.Linear): + mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight)) + +Fundamentally, the flow works by manipulating ``torch.Tensors``. In the frontend, we specify the tensors by their fully-qualified-name in a sparse_config dictionary. The frontend is designed to follow the quantization API, with a ``prepare`` function, which attaches FakeSparsity paramerizations to the tensors specified in the config. + +FakeSparsity is a parameterization which simulates unstructured sparsity, where each element has a mask. Because of this, we can use it to simulate any sparsity pattern we want. + +The user will then train the prepared model using their own custom code, calling ``.step()`` to update the mask if necessary. Once they’ve found a suitable mask, they call ``squash_mask()`` to fuse the mask into the weights, creating a dense tensor with 0s in the right spot. + +Users will then convert their model for accelerated sparse inference by either using the quantization flow for quantized block sparse CPU inference or by calling ``to_sparse_semi_structured`` on the specified weight tensors. + +Context +======= + +This section provides some context on neural network pruning/sparsity as well as definitions for some common pruning/sparsity terms. In academia / industry, **pruning** and **sparsity** are often used interchangeably to refer to the same thing. This can be confusing, especially since sparsity is an overloaded term that can refer to many other things, such as sparse tensor representations. + +Note that this section focuses on **pruning**, instead of **sparse training**. The distinction being that in **pruning** we start with a pretrained dense model, while during **sparse training** we train a sparse model from scratch. + +In order to avoid confusion, we generally try to use sparsity to refer to tensors. Note that a sparse tensor can refer to a dense tensor with many zero values, or a tensor stored using a sparse representation. We describe the flow as **pruning** and the resultant model as a **pruned** model. + +Roughly, the flow for achieving a more performant pruned model looks like this: + + +.. image:: ../static/pruning_flow.png + :alt: flow + + +The general idea behind pruning is that we can mask out some of the weights of a trained neural network and recover any accuracy loss. The resultant pruned model can be run on optimized kernels that take advantage of this sparsity for accelerated inference. + +Zeroing out pruned parameters doesn’t affect the latency / memory overhead of the model out of the box. This is because the dense tensor itself still contains the pruned elements (the 0 elements) and will still compute using those elements during a matrix multiply. In order to realize performance gains, we need to swap out our dense kernels for sparse kernels. + +Loosely speaking, these sparse representations allow us to skip calculations involving pruned elements in order to speed up matrix multiplication. To do this, these optimized sparse kernels work on sparse matrices that are stored in a more efficient format. Some sparse tensor layouts are tightly coupled to specific backends, like NVIDIA 2:4, while others are more general and are supported by more than one backend (CSC is supported by FBGEMM and QNNPACK). + + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Name + Description + How the sparse matrix is stored +
COO (sparse_coo) + COOrdinate format to store sparse matrices. The matrices are stored as a combination of the non-sparse data vector and the index locations of those elements in the dense matrix. + sparse matrix = {Index: Tensor of coordinate locations, + Data: Tensor of values corresponding to index locations } +
BSR (sparse_bsr) + Block sparse row format to store sparse matrices. The matrices are stored as data blocks and the index locations of those blocks in the dense matrix. Very similar to COO, except that individual data consists of blocks, not scalars. + sparse matrix = {Index: Tensor of coordinate locations, two dimensional for a matrix, + Data: Tensor of blocks corresponding to index locations } + where a block is a matrix corresponding to the sparsity pattern. +
CSR (sparse_csr) / CSC (sparse_csc) + Compressed sparse row /column format to store sparse matrices. The sparse matrices are stored as data blocks on columns / rows and indices of those rows/columns in a dense matrix. This is the most compact format for storing block sparse matrices. + sparse_matrix = {Index: 1D tensor of column indices, + IndexPtr: 1D tensor specifying the start and end indices of columns for rows, starting from row 0, + Data: Tensor of blocks corresponding to Index locations.} +
NVIDIA 2:4 compressed representation + Custom NVIDIA compressed storage format for 2:4 semi-structured sparsity. We store the sparse matrix as a compressed dense matrix (½ the size) containing the non-pruned elements and a bitmask index. When multiplying our sparse matrix by another dense matrix, we use the mask to index into the dense matrix and multiply with our compressed dense matrix. + sparse_matrix = {Bitmask: 2bit indices of pruned elements Compressed dense matrix: contains all unpruned elements, half the size of original dense matrix} +
+ + +*Table 4.1: Overview of common sparse tensor layouts.* + +While the general idea of pruning is quite simple, there are many details that a user must figure out before they can successfully prune a model. + +These can be loosely broken down as follows: + + +* **Pruning Configuration** - What layers should I prune? What sparsity level should I prune to? +* **Pruning Criteria** - How should I decide which parameters to remove? +* **Pruning Strategy** - Once I have removed parameters, how can I recover any accuracy degradation? +* **Sparsity Pattern** - Should I try to use a specific sparsity pattern when I prune my model? Different hardware backends support accelerated inference for different sparsity patterns. + +Pruning Configuration +^^^^^^^^^^^^^^^^^^^^^ + +Not all layers in a neural network are created equal. Some layers can be more sensitive to pruning than others. The user must decide what layers to prune and also the **sparsity level** for each layer, which is the % of 0s for that weight tensor. The pruning configuration has an effect on both the accuracy and speedup of the pruned model. + +Determining the best pruning configuration and sparsity level for a given model is an open problem and a general solution does not exist. This is in part because the optimal pruning configuration is dependent on the subsequent pruning criteria and strategy, and there are an infinite number of ways to decide how to prune models and how to recover lost accuracy. + +One common method to determine which layers to prune and to what degree is to perform sensitivity analysis by pruning each layer in the model at different sparsity levels and seeing the subsequent accuracy drop (without retraining). This gives a user a sparsity-accuracy curve for each layer that the user can then use as a proxy to determine the best pruning configuration. + +Pruning Criteria +^^^^^^^^^^^^^^^^ + +A user must decide on a criteria for removing parameters from a neural network. Much like determining the best pruning configuration, determining the best pruning criteria is an open research question and is dependent on the other aforementioned factors. + +The most common pruning criteria is to use weight magnitude. The idea is that low-magnitude weights contribute less than high-magnitude weights to the model output. If we want to remove parameters, we can remove the weights that have the smallest absolute value. + +However, even with a simple pruning criteria such as weight magnitude, there are additional factors that a user would have to consider: + + +* Local vs global scope + + * **Local scope** implies that the sparsity mask is only computed with respect to the layer statistics. + + * Pros: Simple mask computing + * Cons: Potentially sub-optimal accuracy vs sparsity tradeoff. + + * **Global scope** means that the sparsity statistics are not bounded by a single layer, but can span over multiple layers if needed. + + * Pros: No need for per-layer thresholds. The tensor statistics is shared across layers, and normalization is used across layers to allow for it. + * Cons: Increased complexity when computing the masks. + +* Tensors used for mask calculation + + * **Weights**\ : Just use the weight tensor in order to calculate the mask. This method is the simplest for inference as the weight tensors are constant. + * **Gradients**\ : Compute importance based on both weights and gradient norms. Common for pre-training based methods. Currently CTR_mobile_feed uses a gradient-based pruning algorithm. + * **Activations**\ : In some research papers, the norm of the activations that are applied with the weight of interest are used to compute the importance score. + +* In place or out of place mask updates + + * **In-place** updates the sparse tensor by performing W = W (Mask). Once the weight tenosr is udpated, the sparse values are zeroed out and cannot be recovered. + + * **Pros**\ : Requires only one copy of the sparse tensor to be stored (+ mask) + * **Cons**\ : Once a mask is applied to a weight, it is zeroed out, all past history is lost. These weights cannot regrow. + + * **Out-of-place** updates don't modify the tensor directly, but perform the following: W' = W (Mask) and dW'= dW (Mask) + + * **Pros**\ : The original tensor is preserved (the masked elements are not updated via backprop). Weights can regrow if the mask changes. This is necessary for PAT. + * **Cons**\ : In addition to the unmasked weights (W), the masked weights (W’) are computed and resident in memory for forward/backward computations. + + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Name + Description + Notes +
Magnitude / Saliency + Remove parameters that have the lowest norm (L1 is commonly used) + Shown to work well with 2:4 semi-structured sparsity. Able to achieve identical accuracy as the original model by repeating the training loop after one-shot magnitude pruning. +
Movement Pruning + These methods aim to use gradient information in order to decide what parameters to remove. The idea is to remove parameters that do not change much during fine-tuning. + Common for pretrained models. +

+ See https://arxiv.org/abs/2005.07683 +

Low-rank factorization + These methods aim to replace Wx with SQx, where S and Q are matrices with lower rank. + Usually these methods use some sort of layer-wise reconstruction, where instead of training the model to recover lost accuracy, they seek to match layer-wise statistics (Find SQx such that L2(SQx, Wx) is minimized). +
Random + Remove parameters randomly + +
+ + +*Table 4.2: Description of some common pruning criteria.* + +Pruning Strategy +^^^^^^^^^^^^^^^^ + +This is a general term that describes the method in which a user tries to recover any accuracy degradation from their pruned model. After pruning a model, it is common to see accuracy degradation of the model, so users usually retrain the pruned model in order to remediate this. The pruning strategy also determines when and how often the model is pruned during model training. + +The line between a pruning strategy and a pruning criteria is not well defined, especially in the case of pruning aware training methods, which update the mask during training. We sometimes use the term **pruning** **algorithm** to refer to the combination of these two items. These two factors, along with the pruning configuration ultimately control the final accuracy of the pruned model. + + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Pruning Strategy + Description + Notes +
Zero-shot + Prune once, don’t retrain the model + These methods rely on more complicated pruning criteria. +

+ This is sometimes referred to as one-shot in literature, but we will use one-shot to refer to pruning once and retraining once. +

One-shot + Prune once, retrain the model once + NVIDIA has shown that one-shot 2:4 semi-structured sparsity pruning generalizes well across a range of common vision / nlp models. \ + \ + The retraining strategy is to simply repeat the training process again. +
Iterative + Prune the model, retrain, repeat + We can iteratively increase the sparsity level, or iteratively prune different layers in the model. +
Pruning Aware Training + Mask is learned during training + Used by CTR_feed for their current pruning algorithm. +
NAS / Multimask + Multiple masks are used during training. This can be thought of a form of neural architecture search. + Used by PySpeech (FastNAS) +
Layer-wise reconstruction + Instead of retraining using a loss function, we try to recover as much information as possible from each layer by using a two model approach similar to knowledge distillation. + See https://arxiv.org/pdf/2204.09656.pdf +
+ + +*Table 4.3: Description of some common pruning strategies.* + +Sparsity Pattern +^^^^^^^^^^^^^^^^ + +A sparsity pattern describes how the pruned parameters are arranged within the model / tensor. + +Recall that in general it is necessary to use optimized sparse kernels in order to achieve performance gains. Depending on the format and the sparsity level of the weight tensor, sparse matrix multiplication can be faster than its dense counterpart. It can also be slower if a tensor is not sufficiently sparse. + +At the most general level, pruning is unstructured -every parameter has it’s own mask. This gives the most flexibility but requires very high sparsity (>98%) in order to provide performance benefits. In order to provide accelerated inference at lower sparsity levels, hardware backends have added support for special sparsity patterns. + +We seek to prune the model so that the weight tensors exhibit the same sparsity pattern as our inference backend. If we are able to recover the accuracy lost while maintaining the sparsity pattern, we can run this model on sparse hardware for accelerated inference without an accuracy penalty. We can also run a model pruned to a different sparsity pattern on our target backend, at the expense of some additional accuracy loss. + +The specific backend hardware and its corresponding sparsity pattern, as well as the pruning configuration ultimately dictates the performance speedups that we observe. If we prune a model using a different pruning criteria it will have the same performance characteristics if it follows the same sparsity pattern and sparsity level. For example, if we decided to remove the highest-magnitude weights instead of the lowest-magnitude weights, we wouldn’t expect that to change the performance characteristics of the pruned model. + + +.. raw:: html + + + + + + + + + + + + + + + + + + + + + + +
Sparsity Pattern + Mask Visualization +

+ (50% sparsity level) +

Unstructured Sparsity + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Fig 2.3: unstructured sparsity +
1 + 0 + 1 + 1 + 0 + 1 + 0 + 1 +
0 + 0 + 1 + 1 + 1 + 1 + 1 + 0 +
1 + 0 + 0 + 0 + 1 + 0 + 1 + 0 +
0 + 1 + 1 + 0 + 0 + 0 + 0 + 1 +
+ + +
2:4 Semi-Structured + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Fig 2.4: 2:4 semi-structured sparsity +
0 + 1 + 1 + 0 + 1 + 0 + 1 + 0 +
0 + 0 + 1 + 1 + 1 + 1 + 0 + 0 +
1 + 0 + 0 + 1 + 0 + 1 + 0 + 1 +
0 + 1 + 0 + 1 + 1 + 0 + 1 + 0 +
+ +
Block Sparsity + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Fig 2.5: 4x4 block-wise structured sparsity +
0 + 0 + 0 + 0 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 1 + 1 + 1 + 1 +
+ +
Structured Sparsity + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ Fig 2.6: row-wise structured sparsity +
1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 +
1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 +
0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 +
+
+ +*Table 4.4: Description of some common sparsity patterns.* + +For more information on our supported APIs and benchmaks please refer `Sparsity README `_. From 602ba86e3fbff201bc32e4e8e74b9fe89321f9e2 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 23 Jan 2025 08:09:59 -0800 Subject: [PATCH 017/115] gate sparsity tests by presence of cusparselt (#1602) Summary: I have a PyTorch build without `cuSparseLt`. Adding logic to properly skip tests which depend on this library being available. Test Plan: Local testing on an H100 without cuSparseLt: ``` pytest test/prototype/test_sparse_api.py -s ``` Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_affine_quantized.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index f08ba7aa72..8be0652e9a 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -23,6 +23,10 @@ is_sm_at_least_89, ) +is_cusparselt_available = ( + hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available() +) + def get_quantization_functions( do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False @@ -91,7 +95,8 @@ def test_tensor_core_layout_transpose(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize( - "apply_quant", get_quantization_functions(True, True, "cuda", True) + "apply_quant", + get_quantization_functions(is_cusparselt_available, True, "cuda", True), ) def test_weights_only(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") @@ -168,7 +173,9 @@ def apply_uint6_weight_only_quant(linear): deregister_aqt_quantized_linear_dispatch(dispatch_condition) - @common_utils.parametrize("apply_quant", get_quantization_functions(True, True)) + @common_utils.parametrize( + "apply_quant", get_quantization_functions(is_cusparselt_available, True) + ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") From d0e434c8d825f7ac69e26585cb2ceb002a287f24 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Thu, 23 Jan 2025 12:43:34 -0500 Subject: [PATCH 018/115] Fix broken link on doc page (#1582) --- docs/source/_templates/layout.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 6bb2207266..f1d3173de2 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -2,7 +2,7 @@ {% block sidebartitle %} {% include "searchbox.html" %} {% endblock %} @@ -22,7 +22,7 @@ // to point to the torchao repo. var overwrite = function (_) { if ($(this).length > 0) { - $(this)[0].href = "https://github.com/pytorch-labs/ao" + $(this)[0].href = "https://github.com/pytorch/ao" } } // PC From e53edaa8a0d31bfc10d5a184c0178787e1a011ac Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 23 Jan 2025 12:02:44 -0800 Subject: [PATCH 019/115] pin nightlies to 20250122 (#1608) Summary: There are test failures with the 20250123 nightly: ``` if not output_graph.export: if not self.guard_manager.check(output_graph.local_scope): reasons = get_guard_fail_reason_helper( self.guard_manager, # type: ignore[arg-type] output_graph.local_scope, CompileContext.current_compile_id(), ) > raise AssertionError(f"Guard check failed: {reasons}") E AssertionError: Guard check failed: 0/0: ___check_metadata_140011526812544_c0/0 E E E You can suppress this exception and fall back to eager by setting: E import torch._dynamo E torch._dynamo.config.suppress_errors = True /home/vasiliy/.conda/envs/pt_nightly_20241006/lib/python3.11/site-packages/torch/_dynamo/guards.py:2468: AssertionError ``` full example: https://ossci-raw-job-status.s3.amazonaws.com/log/pytorch/ao/36071578472 Pin to the previous day for now until the problem is fixed in pytorch/pytorch Test Plan: Reviewers: Subscribers: Tasks: Tags: --- .github/workflows/float8_test.yml | 4 ++-- .github/workflows/nightly_smoke_test.yml | 4 ++-- .github/workflows/regression_test.yml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/float8_test.yml b/.github/workflows/float8_test.yml index 7c9e5a4b00..b77a50ed2c 100644 --- a/.github/workflows/float8_test.yml +++ b/.github/workflows/float8_test.yml @@ -25,9 +25,9 @@ jobs: include: - name: SM-89 runs-on: linux.g6.4xlarge.experimental.nvidia.gpu - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cu124' gpu-arch-type: "cuda" - gpu-arch-version: "12.1" + gpu-arch-version: "12.4" permissions: id-token: write diff --git a/.github/workflows/nightly_smoke_test.yml b/.github/workflows/nightly_smoke_test.yml index 18d4f41af6..57486bf58f 100644 --- a/.github/workflows/nightly_smoke_test.yml +++ b/.github/workflows/nightly_smoke_test.yml @@ -21,9 +21,9 @@ jobs: include: - name: CUDA Nightly runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cu124' gpu-arch-type: "cuda" - gpu-arch-version: "12.1" + gpu-arch-version: "12.4" permissions: id-token: write diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 19c033c4d1..14c31014c3 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -25,12 +25,12 @@ jobs: include: - name: CUDA Nightly runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu124' + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cu124' gpu-arch-type: "cuda" gpu-arch-version: "12.4" - name: CPU Nightly runs-on: linux.4xlarge - torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu' + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cpu' gpu-arch-type: "cpu" gpu-arch-version: "" From 52280bbb69e29ccde28b529157e313f849bd9ff0 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 23 Jan 2025 15:59:23 -0800 Subject: [PATCH 020/115] [BE] Only run docs build in CI if docs have changed (#1589) only run docs build in CI if docs have changed --- .github/workflows/doc_build.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/doc_build.yml b/.github/workflows/doc_build.yml index 19c1204e6d..d16ed0340b 100644 --- a/.github/workflows/doc_build.yml +++ b/.github/workflows/doc_build.yml @@ -9,6 +9,9 @@ on: tags: - v[0-9]+.[0-9]+.[0-9] - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + paths: + - 'docs/**' + - '!docs/**' pull_request: workflow_dispatch: From 2d4c8482d306c18796fb6d478fac2bcc410f9487 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 23 Jan 2025 16:00:48 -0800 Subject: [PATCH 021/115] [float8nocompile] Add float8nocompile CI tests which only trigger on relevant code changes (#1570) add float8nocompile CI tests --- .github/workflows/float8nocompile_test.yaml | 55 +++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 .github/workflows/float8nocompile_test.yaml diff --git a/.github/workflows/float8nocompile_test.yaml b/.github/workflows/float8nocompile_test.yaml new file mode 100644 index 0000000000..75df32a5d4 --- /dev/null +++ b/.github/workflows/float8nocompile_test.yaml @@ -0,0 +1,55 @@ +name: Run Float8nocompile Tests + +on: + push: + branches: + - main + - 'gh/**' + paths: + - 'torchao/prototype/float8nocompile/**' + - '!torchao/prototype/float8nocompile/**' + pull_request: + branches: + - main + - 'gh/**' + paths: + - 'torchao/prototype/float8nocompile/**' + - '!torchao/prototype/float8nocompile/**' + +concurrency: + group: floatnocompile_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + +jobs: + test: + strategy: + fail-fast: false + matrix: + include: + - name: SM-89 + runs-on: linux.g6.4xlarge.experimental.nvidia.gpu + torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121' + gpu-arch-type: "cuda" + gpu-arch-version: "12.1" + + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 300 + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive + script: | + conda create -n venv python=3.9 -y + conda activate venv + export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH + python -m pip install --upgrade pip + pip install ${{ matrix.torch-spec }} + pip install -r dev-requirements.txt + pip install . + cd torchao/prototype/float8nocompile + pytest kernels/ --verbose -s + pytest test/train_test.py --verbose -s From 4ed93b996b0dc9abd6ac105fec7c9fa52e9a23b3 Mon Sep 17 00:00:00 2001 From: Xia Weiwen Date: Thu, 23 Jan 2025 17:47:20 -0800 Subject: [PATCH 022/115] [CPU] Fix registration of int4wo linear implementation on CPU (#1578) * [CPU] Fix registration of int4wo linear implementation on CPU * Fix format issues * Fix format issues (2) * Fix bug for 3d input * fix format issue * Remove autocast from UT --- test/quantization/test_quant_api.py | 22 +++++ torchao/dtypes/affine_quantized_tensor_ops.py | 8 ++ torchao/dtypes/uintx/int4_cpu_layout.py | 86 ++++++++++++++++++- .../dtypes/uintx/tensor_core_tiled_layout.py | 12 +-- torchao/quantization/quant_api.py | 4 +- 5 files changed, 118 insertions(+), 14 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 177c357047..caba1cf31f 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -761,6 +761,28 @@ def reset_memory(): assert param.is_cuda self.assertLess(memory_streaming, memory_baseline) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + def test_int4wo_cpu(self, dtype, x_dim): + from torchao.dtypes import Int4CPULayout + + device = "cpu" + m = ToyLinearModel().eval().to(dtype).to(device) + example_inputs = m.example_inputs(dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + with torch.no_grad(): + quantize_(m, int4_weight_only(group_size=32, layout=Int4CPULayout())) + # ensure the expected op is in the code + _, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + assert "_weight_int4pack_mm_for_cpu" in code[0] + assert "aten.mm.default" not in code[0] + class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 76df949852..ef8691699e 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -28,6 +28,10 @@ _linear_fp_act_int4_weight_gemlite_check, _linear_fp_act_int4_weight_gemlite_impl, ) +from torchao.dtypes.uintx.int4_cpu_layout import ( + _linear_fp_act_uint4_weight_cpu_check, + _linear_fp_act_uint4_weight_cpu_impl, +) from torchao.dtypes.uintx.marlin_qqq_tensor import ( _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, @@ -151,6 +155,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ), + ( + _linear_fp_act_uint4_weight_cpu_check, + _linear_fp_act_uint4_weight_cpu_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 248f7e1b94..7c734a8a44 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -2,10 +2,17 @@ from typing import Optional, Tuple import torch -from torch.utils._python_dispatch import return_and_correct_aliasing +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) -from torchao.dtypes.affine_quantized_tensor import register_layout +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device +from torchao.quantization.quant_primitives import ZeroPointDomain from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, @@ -126,7 +133,7 @@ def from_plain( zero_point = zero_point.reshape(int_data.shape[0], -1) from torchao.quantization.utils import pack_tinygemm_scales_and_zeros - scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype) return cls(packed_weight, scale_and_zero, False, _layout) def to(self, *args, **kwargs): @@ -231,7 +238,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: groupsize = int(original_shape[1] / scale.shape[-2]) block_size = (1, groupsize) device = self.device - original_dtype = torch.bfloat16 + original_dtype = self.scale_and_zero.dtype target_dtype = torch.int32 quant_min = 0 quant_max = 15 @@ -261,3 +268,74 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: def get_layout(self) -> Layout: return self._layout + + +def _aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.uint8 + and aqt.quant_min == 0 + and aqt.quant_max == 15 + ) + + +def _is_float(dtype): + return dtype in (torch.float, torch.half, torch.bfloat16) + + +def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias): + return ( + TORCH_VERSION_AT_LEAST_2_6 + and is_device(input_tensor.device.type, "cpu") + and is_device(weight_tensor.device.type, "cpu") + and (bias is None or is_device(bias.device.type, "cpu")) + and not is_traceable_wrapper_subclass(input_tensor) + and _is_float(input_tensor.dtype) + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_uint4(weight_tensor) + and _is_float(weight_tensor.dtype) + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + and isinstance(weight_tensor._layout, Int4CPULayout) + ) + + +def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias): + assert ( + TORCH_VERSION_AT_LEAST_2_6 + ), f"Requires PyTorch version at least 2.6, but got: {torch.__version__}" + assert is_device( + input_tensor.device.type, "cpu" + ), f"For CPU device only but got: {input_tensor.device}" + assert ( + weight_tensor.block_size[0] == 1 + ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + packed_weight = weight_tensor.tensor_impl.packed_weight + scale_and_zero = weight_tensor.tensor_impl.scale_and_zero + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + y = torch.ops.aten._weight_int4pack_mm_for_cpu( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + if bias is not None: + y += bias + return y.to(orig_dtype) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 7de869df2d..378744e7e1 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -15,7 +15,6 @@ from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, fill_defaults, find_multiple, ) @@ -76,14 +75,9 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - if is_device(input_tensor.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6: - y = torch.ops.aten._weight_int4pack_mm_for_cpu( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) - else: - y = torch.ops.aten._weight_int4pack_mm( - act_mat.contiguous(), packed_weight, groupsize, scale_and_zero - ) + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b2eff196fd..3a73b97ad1 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -725,7 +725,9 @@ def apply_int4_weight_only_quant(weight): quant_max = 15 eps = 1e-6 preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] - zero_point_dtype = torch.bfloat16 + zero_point_dtype = ( + weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16 + ) nonlocal zero_point_domain assert ( From 0fae69377ea9ec7e16e2e27f489e7b8c9c992b5c Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 24 Jan 2025 10:06:25 -0800 Subject: [PATCH 023/115] Add H100 to Float8 CI for testing (#1575) --- .github/workflows/float8_test.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/float8_test.yml b/.github/workflows/float8_test.yml index b77a50ed2c..3cf2d13933 100644 --- a/.github/workflows/float8_test.yml +++ b/.github/workflows/float8_test.yml @@ -28,6 +28,11 @@ jobs: torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/cu124' gpu-arch-type: "cuda" gpu-arch-version: "12.4" + - name: H100 + runs-on: linux.aws.h100 + torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124' + gpu-arch-type: "cuda" + gpu-arch-version: "12.4" permissions: id-token: write From 4e4f4df091ce50d1a97a34f156f4b667f894aac4 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 24 Jan 2025 13:43:51 -0500 Subject: [PATCH 024/115] Add quick start guide for first time users (#1611) Documentation in torchao has been pretty low-level and geared towards developers so far. This commit adds a basic quick start guide for first time users to get familiar with our main quantization flow. --- .gitignore | 2 +- docs/source/contributor_guide.rst | 2 +- docs/source/getting-started.rst | 4 - docs/source/index.rst | 17 ++-- docs/source/overview.rst | 4 - docs/source/quantization.rst | 6 +- docs/source/quick_start.rst | 136 ++++++++++++++++++++++++++++++ docs/source/sparsity.rst | 6 +- scripts/quick_start.py | 61 ++++++++++++++ 9 files changed, 213 insertions(+), 25 deletions(-) delete mode 100644 docs/source/getting-started.rst delete mode 100644 docs/source/overview.rst create mode 100644 docs/source/quick_start.rst create mode 100644 scripts/quick_start.py diff --git a/.gitignore b/.gitignore index 5fa7064cbe..726d2976f6 100644 --- a/.gitignore +++ b/.gitignore @@ -262,7 +262,7 @@ docs/dev docs/build docs/source/tutorials/* docs/source/gen_modules/* -docs/source/sg_execution_times +docs/source/sg_execution_times.rst # LevelDB files *.sst diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index a69c410e6c..e76b9420d0 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -1,4 +1,4 @@ -torchao Contributor Guide +Contributor Guide ------------------------- .. toctree:: diff --git a/docs/source/getting-started.rst b/docs/source/getting-started.rst deleted file mode 100644 index 70ac60b4a0..0000000000 --- a/docs/source/getting-started.rst +++ /dev/null @@ -1,4 +0,0 @@ -Getting Started -=============== - -TBA diff --git a/docs/source/index.rst b/docs/source/index.rst index 3bbcd203fd..04a53ce454 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,26 +1,25 @@ Welcome to the torchao Documentation -======================================= +==================================== -`torchao `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on: - -1. Getting Started -2. Developer Notes -3. API Reference -4. Tutorials +`torchao `__ is a library for custom data types and optimizations. +Quantize and sparsify weights, gradients, optimizers, and activations for inference and training +using native PyTorch. Please checkout torchao `README `__ +for an overall introduction to the library and recent highlight and updates. .. toctree:: :glob: :maxdepth: 1 :caption: Getting Started - getting-started - sparsity + quick_start .. toctree:: :glob: :maxdepth: 1 :caption: Developer Notes + quantization + sparsity contributor_guide .. toctree:: diff --git a/docs/source/overview.rst b/docs/source/overview.rst deleted file mode 100644 index 4c6d532067..0000000000 --- a/docs/source/overview.rst +++ /dev/null @@ -1,4 +0,0 @@ -Overview -======== - -TBA diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index d96a3afc18..b5e34780b7 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -1,4 +1,4 @@ -Quantization -============ +Quantization Overview +--------------------- -TBA +Coming soon! diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst new file mode 100644 index 0000000000..fea8bb912d --- /dev/null +++ b/docs/source/quick_start.rst @@ -0,0 +1,136 @@ +Quick Start Guide +----------------- + +In this quick start guide, we will explore how to perform basic quantization using torchao. +First, install the latest stable torchao release:: + + pip install torchao + +If you prefer to use the nightly release, you can install torchao using the following +command instead:: + + pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 + +torchao is compatible with the latest 3 major versions of PyTorch, which you will also +need to install (`detailed instructions `__):: + + pip install torch + + +First Quantization Example +========================== + +The main entry point for quantization in torchao is the `quantize_ `__ API. +This function mutates your model inplace to insert the custom quantization logic based +on what the user configures. All code in this guide can be found in this `example script `__. +First, let's set up our toy model: + +.. code:: py + + import copy + import torch + + class ToyLinearModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + + # Optional: compile model for faster inference and generation + model = torch.compile(model, mode="max-autotune", fullgraph=True) + model_bf16 = copy.deepcopy(model) + +Now we call our main quantization API to quantize the linear weights +in the model to int4 inplace. More specifically, this applies uint4 +weight-only asymmetric per-group quantization, leveraging the +`tinygemm int4mm CUDA kernel `__ +for efficient mixed dtype matrix multiplication: + +.. code:: py + + # torch 2.4+ only + from torchao.quantization import int4_weight_only, quantize_ + quantize_(model, int4_weight_only(group_size=32)) + +The quantized model is now ready to use! Note that the quantization +logic is inserted through tensor subclasses, so there is no change +to the overall model structure; only the weights tensors are updated, +but `nn.Linear` modules stay as `nn.Linear` modules: + +.. code:: py + + >>> model.linear1 + Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15)) + + >>> model.linear2 + Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15)) + +First, verify that the int4 quantized model is roughly a quarter of +the size of the original bfloat16 model: + +.. code:: py + + >>> import os + >>> torch.save(model, "/tmp/int4_model.pt") + >>> torch.save(model_bf16, "/tmp/bfloat16_model.pt") + >>> int4_model_size_mb = os.path.getsize("/tmp/int4_model.pt") / 1024 / 1024 + >>> bfloat16_model_size_mb = os.path.getsize("/tmp/bfloat16_model.pt") / 1024 / 1024 + + >>> print("int4 model size: %.2f MB" % int4_model_size_mb) + int4 model size: 1.25 MB + + >>> print("bfloat16 model size: %.2f MB" % bfloat16_model_size_mb) + bfloat16 model size: 4.00 MB + +Next, we demonstrate that not only is the quantized model smaller, +it is also much faster! + +.. code:: py + + from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + benchmark_model, + unwrap_tensor_subclass, + ) + + # Temporary workaround for tensor subclass + torch.compile + # Only needed for torch version < 2.5 + if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(model) + + num_runs = 100 + torch._dynamo.reset() + example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) + bf16_time = benchmark_model(model_bf16, num_runs, example_inputs) + int4_time = benchmark_model(model, num_runs, example_inputs) + + print("bf16 mean time: %0.3f ms" % bf16_time) + print("int4 mean time: %0.3f ms" % int4_time) + print("speedup: %0.1fx" % (bf16_time / int4_time)) + +On a single A100 GPU with 80GB memory, this prints:: + + bf16 mean time: 30.393 ms + int4 mean time: 4.410 ms + speedup: 6.9x + + +Next Steps +========== + +In this quick start guide, we learned how to quantize a simple model with +torchao. To learn more about the different workflows supported in torchao, +see our main `README `__. +For a more detailed overview of quantization in torchao, visit +`this page `__. + +Finally, if you would like to contribute to torchao, don't forget to check +out our `contributor guide `__ and our list of +`good first issues `__ on Github! diff --git a/docs/source/sparsity.rst b/docs/source/sparsity.rst index 0bde173b6d..d9986a3227 100644 --- a/docs/source/sparsity.rst +++ b/docs/source/sparsity.rst @@ -1,5 +1,5 @@ -Sparsity --------- +Sparsity Overview +----------------- Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve significant reduction in memory overhead and latency, while paying a reasonably low or no price in terms of model quality (accuracy / f1). @@ -38,7 +38,7 @@ Given a target sparsity pattern, pruning/sparsifying a model can then be thought * **Accuracy** - How can I find a set of sparse weights which satisfy my target sparsity pattern that minimize the accuracy degradation of my model? -* **Perforance** - How can I accelerate my sparse weights for inference and reduce memory overhead? +* **Performance** - How can I accelerate my sparse weights for inference and reduce memory overhead? Our workflow is designed to consist of two parts that answer each question independently: diff --git a/scripts/quick_start.py b/scripts/quick_start.py new file mode 100644 index 0000000000..f2e195fd7e --- /dev/null +++ b/scripts/quick_start.py @@ -0,0 +1,61 @@ +import copy + +import torch + +from torchao.quantization import int4_weight_only, quantize_ +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + benchmark_model, + unwrap_tensor_subclass, +) + +# ================ +# | Set up model | +# ================ + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + +# Optional: compile model for faster inference and generation +model = torch.compile(model, mode="max-autotune", fullgraph=True) +model_bf16 = copy.deepcopy(model) + + +# ======================== +# | torchao quantization | +# ======================== + +# torch 2.4+ only +quantize_(model, int4_weight_only(group_size=32)) + + +# ============= +# | Benchmark | +# ============= + +# Temporary workaround for tensor subclass + torch.compile +# Only needed for torch version < 2.5 +if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(model) + +num_runs = 100 +torch._dynamo.reset() +example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) +bf16_time = benchmark_model(model_bf16, num_runs, example_inputs) +int4_time = benchmark_model(model, num_runs, example_inputs) + +print("bf16 mean time: %0.3f ms" % bf16_time) +print("int4 mean time: %0.3f ms" % int4_time) +print("speedup: %0.1fx" % (bf16_time / int4_time)) From 70be2452f3ae4fbd13ab61609732878baa990c84 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 24 Jan 2025 11:27:48 -0800 Subject: [PATCH 025/115] Move fpx to tensor subclass (#1603) --- torchao/dtypes/__init__.py | 6 +- torchao/dtypes/affine_quantized_tensor.py | 87 +++++-------------- torchao/dtypes/floatx/__init__.py | 4 + .../floatx/floatx_tensor_core_layout.py | 57 ++++++++++++ 4 files changed, 87 insertions(+), 67 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 9cbd4cd2a0..d043a13af9 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -4,12 +4,14 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, # experimental, will be merged into floatx in the future - to_affine_quantized_fpx, to_affine_quantized_intx, to_affine_quantized_intx_static, ) from .floatx import ( Float8Layout, + FloatxTensor, + FloatxTensorCoreLayout, + to_affine_quantized_fpx, ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( @@ -52,4 +54,6 @@ "MarlinQQQLayout", "Int4CPULayout", "CutlassInt4PackedLayout", + "FloatxTensor", + "FloatxTensorCoreLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e7aca34c5f..eedca7e1cb 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -14,12 +14,9 @@ MappingType, ZeroPointDomain, choose_qparams_affine, - choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, dequantize_affine, - dequantize_affine_floatx, quantize_affine, - quantize_affine_floatx, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -36,7 +33,6 @@ "to_affine_quantized_floatx", "to_affine_quantized_intx_static", "to_affine_quantized_floatx_static", - "to_affine_quantized_fpx", ] @@ -126,40 +122,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - from torchao.dtypes.floatx import FloatxTensorCoreLayout - - if isinstance(self._layout, FloatxTensorCoreLayout): - int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx( - int_data, - scale, - self._layout.ebits, - self._layout.mbits, - output_dtype=output_dtype, - ) - else: - data, scale, zero_point = self.tensor_impl.get_plain() - dq = dequantize_affine( - data, - self.block_size, - scale, - zero_point, - data.dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain, - output_dtype=output_dtype, - ) - from torchao.dtypes.uintx import TensorCoreTiledLayout + data, scale, zero_point = self.tensor_impl.get_plain() + dq = dequantize_affine( + data, + self.block_size, + scale, + zero_point, + data.dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain, + output_dtype=output_dtype, + ) + from torchao.dtypes.uintx import TensorCoreTiledLayout - if isinstance(self._layout, TensorCoreTiledLayout): - # need to return to original shape if tensor was padded - # in preprocessing - # TODO: we could add an API for this if there are more use cases - # (e.g. dequant_post_process) in TensorImpl or Layout - for dim, dim_size in enumerate(self.shape): - dq = dq.narrow(dim, 0, dim_size) - return dq + if isinstance(self._layout, TensorCoreTiledLayout): + # need to return to original shape if tensor was padded + # in preprocessing + # TODO: we could add an API for this if there are more use cases + # (e.g. dequant_post_process) in TensorImpl or Layout + for dim, dim_size in enumerate(self.shape): + dq = dq.narrow(dim, 0, dim_size) + return dq def __tensor_flatten__(self): return ["tensor_impl"], [ @@ -395,33 +379,6 @@ def from_hp_to_floatx_static( f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" ) - @classmethod - def from_hp_to_fpx( - cls, - input_float: torch.Tensor, - _layout: Layout, - ): - from torchao.dtypes.floatx import FloatxTensorCoreLayout - - assert isinstance( - _layout, FloatxTensorCoreLayout - ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" - original_shape = input_float.shape - input_float = _layout.pre_process(input_float) - # per axis quantization, where axis = 1 - block_size = list(input_float.shape) - block_size[1] = 1 - - ebits, mbits = _layout.ebits, _layout.mbits - # Note: these ops are hardcoded to have per axis quantization (axis=1) right now - scale = choose_qparams_affine_floatx(input_float, ebits, mbits) - floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) - floatx_packed = _layout.post_process(floatx_unpacked) - - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) - return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) - @property def _layout(self) -> Layout: return self.tensor_impl._layout @@ -477,8 +434,6 @@ def _apply_fn_to_data(self, fn): to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static -# experimental will be merged in to floatx -to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 3f0a1ccd5c..4bfaa3de9e 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,7 +1,9 @@ from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( + FloatxTensor, FloatxTensorCoreLayout, from_scaled_tc_floatx, + to_affine_quantized_fpx, to_scaled_tc_floatx, ) @@ -10,4 +12,6 @@ "to_scaled_tc_floatx", "from_scaled_tc_floatx", "Float8Layout", + "to_affine_quantized_fpx", + "FloatxTensor", ] diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 0f67e9826e..99d07fd4e0 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -11,6 +11,7 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, + get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.utils import ( @@ -22,6 +23,11 @@ _floatx_unpacked_to_f32, _n_ones, ) +from torchao.quantization.quant_primitives import ( + choose_qparams_affine_floatx, + dequantize_affine_floatx, + quantize_affine_floatx, +) aten = torch.ops.aten _ONES_TABLE = [_n_ones(i) for i in range(8)] @@ -456,6 +462,54 @@ class FloatxTensorCoreLayout(Layout): mbits: int +class FloatxTensor(AffineQuantizedTensor): + """ + Floatx quantized tensor subclass which inherits AffineQuantizedTensor class. It uses floating-point format defined by ebits (exponent bits) and mbits (mantissa bits) and supports float1 - float7 tensor types. + For details about float8 tensor type, please refer to https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/float8_layout.py. + + To see what happens during choose_qparams_and_quantize_affine_fpx, quantization and dequantization for floatx quantization, + please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py + and check the two quant primitive ops: choose_qparams_affine_floatx, quantize_affine_floatx and dequantize_affine_floatx. + """ + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + int_data, scale = self.tensor_impl.get_plain() + return dequantize_affine_floatx( + int_data, + scale, + self._layout.ebits, + self._layout.mbits, + output_dtype=output_dtype, + ) + + @classmethod + def from_hp_to_floatx( + cls, + input_float: torch.Tensor, + _layout: Layout, + ): + assert isinstance( + _layout, FloatxTensorCoreLayout + ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + # per axis quantization, where axis = 1 + block_size = list(input_float.shape) + block_size[1] = 1 + + ebits, mbits = _layout.ebits, _layout.mbits + # Note: these ops are hardcoded to have per axis quantization (axis=1) right now + scale = choose_qparams_affine_floatx(input_float, ebits, mbits) + floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) + floatx_packed = _layout.post_process(floatx_unpacked) + + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) + return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) + + @register_layout(FloatxTensorCoreLayout) class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), @@ -657,3 +711,6 @@ def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): out += bias return out.view(*act.shape[:-1], out_dim).to(act.dtype) + + +to_affine_quantized_fpx = FloatxTensor.from_hp_to_floatx From fb335e08f1c970f3c9b1f0eb7d214cfeded7fbaf Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 24 Jan 2025 11:57:33 -0800 Subject: [PATCH 026/115] Revert "Move fpx to tensor subclass" (#1616) Revert "Move fpx to tensor subclass (#1603)" This reverts commit 70be2452f3ae4fbd13ab61609732878baa990c84. --- torchao/dtypes/__init__.py | 6 +- torchao/dtypes/affine_quantized_tensor.py | 87 ++++++++++++++----- torchao/dtypes/floatx/__init__.py | 4 - .../floatx/floatx_tensor_core_layout.py | 57 ------------ 4 files changed, 67 insertions(+), 87 deletions(-) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d043a13af9..9cbd4cd2a0 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -4,14 +4,12 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, # experimental, will be merged into floatx in the future + to_affine_quantized_fpx, to_affine_quantized_intx, to_affine_quantized_intx_static, ) from .floatx import ( Float8Layout, - FloatxTensor, - FloatxTensorCoreLayout, - to_affine_quantized_fpx, ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( @@ -54,6 +52,4 @@ "MarlinQQQLayout", "Int4CPULayout", "CutlassInt4PackedLayout", - "FloatxTensor", - "FloatxTensorCoreLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index eedca7e1cb..e7aca34c5f 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -14,9 +14,12 @@ MappingType, ZeroPointDomain, choose_qparams_affine, + choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, dequantize_affine, + dequantize_affine_floatx, quantize_affine, + quantize_affine_floatx, ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, @@ -33,6 +36,7 @@ "to_affine_quantized_floatx", "to_affine_quantized_intx_static", "to_affine_quantized_floatx_static", + "to_affine_quantized_fpx", ] @@ -122,28 +126,40 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - data, scale, zero_point = self.tensor_impl.get_plain() - dq = dequantize_affine( - data, - self.block_size, - scale, - zero_point, - data.dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain, - output_dtype=output_dtype, - ) - from torchao.dtypes.uintx import TensorCoreTiledLayout + from torchao.dtypes.floatx import FloatxTensorCoreLayout - if isinstance(self._layout, TensorCoreTiledLayout): - # need to return to original shape if tensor was padded - # in preprocessing - # TODO: we could add an API for this if there are more use cases - # (e.g. dequant_post_process) in TensorImpl or Layout - for dim, dim_size in enumerate(self.shape): - dq = dq.narrow(dim, 0, dim_size) - return dq + if isinstance(self._layout, FloatxTensorCoreLayout): + int_data, scale = self.tensor_impl.get_plain() + return dequantize_affine_floatx( + int_data, + scale, + self._layout.ebits, + self._layout.mbits, + output_dtype=output_dtype, + ) + else: + data, scale, zero_point = self.tensor_impl.get_plain() + dq = dequantize_affine( + data, + self.block_size, + scale, + zero_point, + data.dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain, + output_dtype=output_dtype, + ) + from torchao.dtypes.uintx import TensorCoreTiledLayout + + if isinstance(self._layout, TensorCoreTiledLayout): + # need to return to original shape if tensor was padded + # in preprocessing + # TODO: we could add an API for this if there are more use cases + # (e.g. dequant_post_process) in TensorImpl or Layout + for dim, dim_size in enumerate(self.shape): + dq = dq.narrow(dim, 0, dim_size) + return dq def __tensor_flatten__(self): return ["tensor_impl"], [ @@ -379,6 +395,33 @@ def from_hp_to_floatx_static( f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" ) + @classmethod + def from_hp_to_fpx( + cls, + input_float: torch.Tensor, + _layout: Layout, + ): + from torchao.dtypes.floatx import FloatxTensorCoreLayout + + assert isinstance( + _layout, FloatxTensorCoreLayout + ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + # per axis quantization, where axis = 1 + block_size = list(input_float.shape) + block_size[1] = 1 + + ebits, mbits = _layout.ebits, _layout.mbits + # Note: these ops are hardcoded to have per axis quantization (axis=1) right now + scale = choose_qparams_affine_floatx(input_float, ebits, mbits) + floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) + floatx_packed = _layout.post_process(floatx_unpacked) + + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) + return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) + @property def _layout(self) -> Layout: return self.tensor_impl._layout @@ -434,6 +477,8 @@ def _apply_fn_to_data(self, fn): to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static +# experimental will be merged in to floatx +to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 4bfaa3de9e..3f0a1ccd5c 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,9 +1,7 @@ from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( - FloatxTensor, FloatxTensorCoreLayout, from_scaled_tc_floatx, - to_affine_quantized_fpx, to_scaled_tc_floatx, ) @@ -12,6 +10,4 @@ "to_scaled_tc_floatx", "from_scaled_tc_floatx", "Float8Layout", - "to_affine_quantized_fpx", - "FloatxTensor", ] diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 99d07fd4e0..0f67e9826e 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -11,7 +11,6 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, - get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.utils import ( @@ -23,11 +22,6 @@ _floatx_unpacked_to_f32, _n_ones, ) -from torchao.quantization.quant_primitives import ( - choose_qparams_affine_floatx, - dequantize_affine_floatx, - quantize_affine_floatx, -) aten = torch.ops.aten _ONES_TABLE = [_n_ones(i) for i in range(8)] @@ -462,54 +456,6 @@ class FloatxTensorCoreLayout(Layout): mbits: int -class FloatxTensor(AffineQuantizedTensor): - """ - Floatx quantized tensor subclass which inherits AffineQuantizedTensor class. It uses floating-point format defined by ebits (exponent bits) and mbits (mantissa bits) and supports float1 - float7 tensor types. - For details about float8 tensor type, please refer to https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/float8_layout.py. - - To see what happens during choose_qparams_and_quantize_affine_fpx, quantization and dequantization for floatx quantization, - please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py - and check the two quant primitive ops: choose_qparams_affine_floatx, quantize_affine_floatx and dequantize_affine_floatx. - """ - - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: - if output_dtype is None: - output_dtype = self.dtype - int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx( - int_data, - scale, - self._layout.ebits, - self._layout.mbits, - output_dtype=output_dtype, - ) - - @classmethod - def from_hp_to_floatx( - cls, - input_float: torch.Tensor, - _layout: Layout, - ): - assert isinstance( - _layout, FloatxTensorCoreLayout - ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" - original_shape = input_float.shape - input_float = _layout.pre_process(input_float) - # per axis quantization, where axis = 1 - block_size = list(input_float.shape) - block_size[1] = 1 - - ebits, mbits = _layout.ebits, _layout.mbits - # Note: these ops are hardcoded to have per axis quantization (axis=1) right now - scale = choose_qparams_affine_floatx(input_float, ebits, mbits) - floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) - floatx_packed = _layout.post_process(floatx_unpacked) - - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) - return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) - - @register_layout(FloatxTensorCoreLayout) class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), @@ -711,6 +657,3 @@ def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): out += bias return out.view(*act.shape[:-1], out_dim).to(act.dtype) - - -to_affine_quantized_fpx = FloatxTensor.from_hp_to_floatx From 6c3bc539155145de8b5dff02b68ddade0d4e67c5 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 24 Jan 2025 12:39:48 -0800 Subject: [PATCH 027/115] Update api_ref_dtypes docs (#1610) --- docs/source/api_ref_dtypes.rst | 33 ++++++++++++++--- torchao/dtypes/affine_quantized_tensor.py | 37 ++++++++++--------- torchao/dtypes/floatx/float8_layout.py | 6 +++ .../floatx/floatx_tensor_core_layout.py | 4 +- torchao/dtypes/nf4tensor.py | 4 +- torchao/dtypes/uintx/block_sparse_layout.py | 6 +++ .../uintx/cutlass_int4_packed_layout.py | 2 + torchao/dtypes/uintx/int4_cpu_layout.py | 7 ++-- torchao/dtypes/uintx/marlin_qqq_tensor.py | 6 ++- torchao/dtypes/uintx/marlin_sparse_layout.py | 11 ++++++ torchao/dtypes/uintx/semi_sparse_layout.py | 7 ++++ .../dtypes/uintx/tensor_core_tiled_layout.py | 10 ++--- torchao/dtypes/uintx/uintx_layout.py | 11 ++++++ torchao/dtypes/utils.py | 19 +++++++--- 14 files changed, 122 insertions(+), 41 deletions(-) diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index fbe680953e..26e1266c09 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -6,19 +6,42 @@ torchao.dtypes .. currentmodule:: torchao.dtypes +Layouts and Tensor Subclasses +----------------------------- +.. autosummary:: + :toctree: generated/ + :nosignatures: + + NF4Tensor + AffineQuantizedTensor + Layout + PlainLayout + SemiSparseLayout + TensorCoreTiledLayout + Float8Layout + FloatxTensor + FloatxTensorCoreLayout + MarlinSparseLayout + BlockSparseLayout + UintxLayout + MarlinQQQTensor + MarlinQQQLayout + Int4CPULayout + CutlassInt4PackedLayout + +Quantization techniques +----------------------- .. autosummary:: :toctree: generated/ :nosignatures: - to_nf4 to_affine_quantized_intx to_affine_quantized_intx_static + to_affine_quantized_fpx to_affine_quantized_floatx to_affine_quantized_floatx_static - to_affine_quantized_fpx - NF4Tensor - AffineQuantizedTensor - + to_marlinqqq_quantized_intx + to_nf4 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation. diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e7aca34c5f..e3ac420de7 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -44,9 +44,8 @@ # Tensor Subclass Definition # ############################## class AffineQuantizedTensor(TorchAOBaseTensor): - """ - Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: - quantized_tensor = float_tensor / scale + zero_point + """Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: + quantized_tensor = float_tensor / scale + zero_point To see what happens during choose_qparams, quantization and dequantization for affine quantization, please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py @@ -56,21 +55,18 @@ class AffineQuantizedTensor(TorchAOBaseTensor): regardless of the internal representation's type or orientation. fields: - tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data, - e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device - and operator/kernel - block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam - e.g. when size is the same as the input tensor dimension, we are using per tensor quantization - shape (torch.Size): the shape for the original high precision Tensor - quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float - if zero_point is in integer domain, zero point is added to the quantized integer value during - quantization - if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) - value during quantization - default is ZeroPointDomain.INT - dtype: dtype for original high precision tensor, e.g. torch.float32 + - tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data, + e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device and operator/kernel + - block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + - shape (torch.Size): the shape for the original high precision Tensor + - quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` + - quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` + - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) value during quantization + default is ZeroPointDomain.INT + - dtype: dtype for original high precision tensor, e.g. torch.float32 """ @staticmethod @@ -207,6 +203,7 @@ def from_hp_to_intx( _layout: Layout = PlainLayout(), use_hqq: bool = False, ): + """Convert a high precision tensor to an integer affine quantized tensor.""" original_shape = input_float.shape input_float = _layout.pre_process(input_float) @@ -302,6 +299,7 @@ def from_hp_to_intx_static( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), ): + """Create an integer AffineQuantizedTensor from a high precision tensor using static parameters.""" if target_dtype not in FP8_TYPES: assert ( zero_point_domain is not None @@ -348,6 +346,7 @@ def from_hp_to_floatx( _layout: Layout, scale_dtype: Optional[torch.dtype] = None, ): + """Convert a high precision tensor to a float8 quantized tensor.""" if target_dtype in FP8_TYPES: return cls.from_hp_to_intx( input_float=input_float, @@ -378,6 +377,7 @@ def from_hp_to_floatx_static( target_dtype: torch.dtype, _layout: Layout, ): + """Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters.""" if target_dtype in FP8_TYPES: return cls.from_hp_to_intx_static( input_float=input_float, @@ -401,6 +401,7 @@ def from_hp_to_fpx( input_float: torch.Tensor, _layout: Layout, ): + """Create a floatx AffineQuantizedTensor from a high precision tensor. Floatx is represented as ebits and mbits, and supports the representation of float1-float7.""" from torchao.dtypes.floatx import FloatxTensorCoreLayout assert isinstance( diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index dd995fb157..5a7e1924b3 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -25,6 +25,12 @@ @dataclass(frozen=True) class Float8Layout(Layout): + """Represents the layout configuration for Float8 affine quantized tensors. + + Attributes: + mm_config (Optional[Float8MMConfig]): Configuration for matrix multiplication operations involving Float8 tensors. If None, default settings are used. + """ + mm_config: Optional[Float8MMConfig] = None diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index 0f67e9826e..beaa2e536e 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -450,7 +450,9 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> # quantization api integrations @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): - """Layout type for FloatxTensorCoreAQTTensorImpl""" + """FloatxTensorCoreLayout is a data class that defines the layout for a tensor with a specific number of exponent bits (ebits) and mantissa bits (mbits). + This layout is used in the context of quantization and packing of tensors optimized for TensorCore operations. + """ ebits: int mbits: int diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 14a8c2d43e..5ae06a1fe1 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -662,10 +662,9 @@ def dequantize_scalers( ) -> torch.Tensor: """Used to unpack the double quantized scalers - Args; + Args: input_tensor: Input tensor to convert to QLoRA format this is the quantized scalers in int8 format quantization_factor: Tensor of per_scaler_block quantization factors stored in inpt_weight.dtype - size: (n_scaler_blocks) scaler_block_size: Scaler block size to use for double quantization. """ @@ -953,6 +952,7 @@ def linear_nf4(input: torch.Tensor, weight: NF4Tensor) -> torch.Tensor: def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256): + """Convert a given tensor to normalized float 4-bit tensor.""" return NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 0670986b13..6681847608 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -27,6 +27,12 @@ @dataclass(frozen=True) class BlockSparseLayout(Layout): + """BlockSparseLayout is a data class that represents the layout of a block sparse matrix. + + Attributes: + blocksize (int): The size of the blocks in the sparse matrix. Default is 64. + """ + blocksize: int = 64 diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index a6412ec88c..9c0d0bb055 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -29,6 +29,8 @@ def _aqt_is_int4(aqt): @dataclass(frozen=True) class CutlassInt4PackedLayout(Layout): + """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel.""" + pass diff --git a/torchao/dtypes/uintx/int4_cpu_layout.py b/torchao/dtypes/uintx/int4_cpu_layout.py index 7c734a8a44..d587591ccc 100644 --- a/torchao/dtypes/uintx/int4_cpu_layout.py +++ b/torchao/dtypes/uintx/int4_cpu_layout.py @@ -24,15 +24,16 @@ @dataclass(frozen=True) class Int4CPULayout(Layout): - """Only for PyTorch version at least 2.6""" + """Layout class for int4 CPU layout for affine quantized tensor, used by tinygemm kernels `_weight_int4pack_mm_for_cpu`. + Only for PyTorch version at least 2.6 + """ pass @register_layout(Int4CPULayout) class Int4CPUAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, + """TensorImpl for int4 CPU layout for affine quantized tensor, this is for int4 only, used by tinygemm kernels `_weight_int4pack_mm_for_cpu` It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of dimension: [n][k / 2] (uint8 dtype) diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index b75d959b41..3a4253bb3f 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -29,8 +29,7 @@ class MarlinQQQTensor(AffineQuantizedTensor): - """ - MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. + """MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py @@ -58,6 +57,7 @@ def from_hp_to_intx( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, _layout: Optional[Layout] = None, ): + """Converts a floating point tensor to a Marlin QQQ quantized tensor.""" original_shape = input_float.shape input_float = _layout.pre_process(input_float) nbits = int(math.log2(quant_max - quant_min + 1)) @@ -81,6 +81,8 @@ def from_hp_to_intx( @dataclass(frozen=True) class MarlinQQQLayout(Layout): + """MarlinQQQLayout is a layout class for Marlin QQQ quantization.""" + pass diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index 2a84dd1813..22763eb0c2 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -71,6 +71,17 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b @dataclass(frozen=True) class MarlinSparseLayout(Layout): + """MarlinSparseLayout is a layout class for handling sparse tensor formats + specifically designed for the Marlin sparse kernel. This layout is used + to optimize the storage and computation of affine quantized tensors with + 2:4 sparsity patterns. + + The layout ensures that the tensor data is pre-processed and stored in a + format that is compatible with the Marlin sparse kernel operations. It + provides methods for preprocessing input tensors and managing the layout + of quantized tensors. + """ + def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index a554fd9bc6..3c35a4d8cd 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -66,6 +66,13 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( @dataclass(frozen=True) class SemiSparseLayout(Layout): + """SemiSparseLayout is a layout class for handling semi-structured sparse + matrices in affine quantized tensors. This layout is specifically designed + to work with the 2:4 sparsity pattern, where two out of every four elements + are pruned to zero. This class provides methods for preprocessing input + tensors to conform to this sparsity pattern. + """ + def pre_process(self, input: torch.Tensor) -> torch.Tensor: # prune to 2:4 if not already temp = input.detach() diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 378744e7e1..b29c9d167b 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -91,9 +91,10 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): @dataclass(frozen=True) class TensorCoreTiledLayout(Layout): - """ - inner_k_tiles is an internal argument for packing function of tensor core tiled layout - that can affect the performance of the matmul kernel + """TensorCoreTiledLayout is a layout class for handling tensor core tiled layouts in affine quantized tensors. It provides methods for pre-processing and post-processing tensors to fit the required layout for efficient computation on tensor cores. + + Attributes: + inner_k_tiles (int): An internal argument for the packing function of tensor core tiled layout that can affect the performance of the matmul kernel. Defaults to 8. """ inner_k_tiles: int = 8 @@ -149,8 +150,7 @@ def extra_repr(self): @register_layout(TensorCoreTiledLayout) class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, + """TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, used by tinygemm kernels `_weight_int4pack_mm` It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py index 29c2ae93fe..ef85319cd5 100644 --- a/torchao/dtypes/uintx/uintx_layout.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -209,6 +209,17 @@ def _(func, types, args, kwargs): @dataclass(frozen=True) class UintxLayout(Layout): + """A layout class for Uintx tensors, which are tensors with elements packed into + smaller bit-widths than the standard 8-bit byte. This layout is used to define + how the data is stored and processed in UintxTensor objects. + + Attributes: + dtype (torch.dtype): The data type of the tensor elements, which determines + the bit-width used for packing. + pack_dim (int): The dimension along which the data is packed. Default is -1, + which indicates the last dimension. + """ + dtype: torch.dtype pack_dim: int = -1 diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 0952b2a4bf..45a0b4312d 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -27,6 +27,15 @@ @dataclass(frozen=True) class Layout: + """The Layout class serves as a base class for defining different data layouts for tensors. + It provides methods for pre-processing and post-processing tensors, as well as static + pre-processing with additional parameters like scale, zero_point, and block_size. + + The Layout class is designed to be extended by other layout classes that define specific + data representations and behaviors for tensors. It is used in conjunction with TensorImpl + classes to represent custom data layouts and how tensors interact with different operators. + """ + def pre_process(self, input: torch.Tensor) -> torch.Tensor: return input @@ -49,13 +58,13 @@ def extra_repr(self) -> str: return "" -""" -Plain Layout, the most basic Layout, also has no extra metadata, will typically be the default -""" - - @dataclass(frozen=True) class PlainLayout(Layout): + """PlainLayout is the most basic layout class, inheriting from the Layout base class. + It does not add any additional metadata or processing steps to the tensor. + Typically, this layout is used as the default when no specific layout is required. + """ + pass From 860da263936aedc153283210f2f86573830625dd Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 24 Jan 2025 15:52:22 -0500 Subject: [PATCH 028/115] Add module swap -> tensor subclass migration tutorial (#1596) Adds a migration tutorial from module swap to tensor subclass for expressing basic quantization. This is a simplified version of the existing subclass tutorials in torchao, removing layers of indirection like Layout and TensorImpl for ease of understanding. This commit also removes overlapping content from the existing contributor guide. Work was done with @bdhirsh. --- docs/source/contributor_guide.rst | 216 +-------- docs/source/index.rst | 2 + docs/source/subclass_advanced.rst | 4 + docs/source/subclass_basic.rst | 462 ++++++++++++++++++++ tutorials/examples/logging_subclass.py | 66 +++ tutorials/examples/quantized_module_swap.py | 72 +++ tutorials/examples/quantized_subclass.py | 183 ++++++++ 7 files changed, 790 insertions(+), 215 deletions(-) create mode 100644 docs/source/subclass_advanced.rst create mode 100644 docs/source/subclass_basic.rst create mode 100644 tutorials/examples/logging_subclass.py create mode 100644 tutorials/examples/quantized_module_swap.py create mode 100644 tutorials/examples/quantized_subclass.py diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index e76b9420d0..7d4d20cc65 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -125,7 +125,7 @@ On the top of the stack will be the final quantization algorithms and quantizati For demonstration purposes, let's say after previous step we have ``AffineQuantizedTensor`` and ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an ``AffineQuantizedTensor`` with corresponding dtype. -Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in ``Tensor Subclass Developer Guide`` section. +Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in the `Writing Your Own Tensor Subclass `__ tutorial. Weight Only Quantization ######################## @@ -257,220 +257,6 @@ During Save/Load Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc `__ for more details. -Tensor Subclass Developer Guide -=============================== - -We have covered high level overview and how everything is connected together in the previous section, this section will focus on Tensor Subclasses, which is the main extension point we rely on to provide flexibility of supporting inference, training and fine tuning with low precision Tensors and composability with torch.compile, autograd, distributed primitives in these scenarios. - -Prerequisites -~~~~~~~~~~~~~ -Some externally available resources for tensor subclasses: - -* `tensor subclass doc `__ -* `Edward's podcast about tensor subclasses `__ -* `Tensor subclass zoo `__ - -Why Tensor Subclass? -~~~~~~~~~~~~~~~~~~~~ -There are multiple ways people can implement quantization techniques or new dtypes, main motivation for us to recommend the tensor subclass based approach are three things: -(1). It’s natural for quantization to be modeled as a dtype conversion, so implementing it with tensor subclass means we are not introducing new concepts but reusing existing concepts like dtype, layout that already exists in pytorch core -(2). Since tensor subclass intercepts computation at torch function or aten ops level, as long as the same function/operator is used, we will be able to quantize the model. This allows the model that’s using variants of native modules (e.g. a slightly modified version of nn.Linear) to still be compatible with quantization -(3). Tensor subclass is also the approach adopted by other techniques like sparsity and distributed, so implementing quantization or dtype conversion with tensor subclass would make it easier for it to be composable with these techniques - -Example Code for a new DType -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Please feel free to start with `tutorial `__ for a end to end working example that combines everything we talked about together and come back to the doc for clarifications and documentations. - -Basic Structure -~~~~~~~~~~~~~~~ -A tensor subclass needs to define a few basic methods: ``__new__``, ``__init__``, ``__tensor_flatten__``, ``__tensor_unflatten__`` -and also dispatch functions for torch functions ``__torch_function__`` and aten ops ``__torch_dispatch__``. - -Here is an example of basic structure:: - # check out docs in https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L437 - from torchao.utils import TorchAOBaseTensor - - class MyDTypeLayout(TorchAOBaseTensor): - # see tutorial code for details - pass - - class MyDtypeTensor(TorchAOBaseTensor): - """We need to define `__new__` for constructing a new tensor subclass instance and `__init__` for initialize - the instance. There is no requirement on what the argument list should look like here, only requirement is - that `__new__` must return a Tensor instance with `torch.Tensor._make_wrapper_subclass(cls, shape, ...)` call - """ - @staticmethod - def __new__( - cls, - tensor_impl: MyDTypeLayout, - shape: torch.Size, - dtype: Optional[torch.dtype] = None, - ): - ... - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - tensor_impl: MyDTypeLayout, - shape: torch.Size, ... - ): - self.tensor_impl = tensor_impl - - - """`__tensor_flatten__` and `__tensor_unflatten__` are used to desugar the tensor into native Tensors/attributes and - reconstruct the tensor subclass instance from the desugared tensor and attributes, these are required to define - a Tensor subclass for torch.compile support - """ - def __tensor_flatten__(self): - return ["tensor_impl"], [self.shape] - - """see https://github.com/pytorch/pytorch/blob/3bc2004f9123a32f381ef64202252d59109507f3/torch/utils/_python_dispatch.py#L289 for documentations for outer_size and outer_stride - """ - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - tensor_impl = tensor_data_dict["tensor_impl"] - shape, = tensor_attributes - return cls( - tensor_impl, - shape if outer_size is None else outer_size, - ) - - - """classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype - """ - @classmethod - def from_float( - cls, - input_float: torch.Tensor, - ): - mapping_type = MappingType.SYMMETRIC - block_size = input_float.shape - dtype = torch.int16 - scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) - int_data = (input_float / scale).to(torch.int8) - tensor_impl = MyDTypeLayout.from_plain(int_data, scale) - return cls(tensor_impl, input_float.shape) - - - """[Optional] see docs for `Layout/Packing` under `Quantized Tensors` section to understand what layout_type is - """ - @property - def _layout(self) -> LayoutType: - return self.tensor_impl._layout - - """There are two entry points that we can modify the behavior of a pytorch op: torch_function and torch_dispatch: - - __torch_function__: will be called whenever a torch level function is called on the Tensor object, for example: torch.nn.functional.linear, - tensor.detach, tensor.reshape, tensor.t etc. - - __torch_dispatch__: will be called in the C++ dispatcher, when an aten operator is called on the Tensor object, for example: - aten.mm, aten.addmm, aten.detach.default, aten.t.default etc. - you can checkout https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L361-L389 to understand what `__torch_function__` and `__torch_dispatch__` are doing, but with `TorchAoBaseTensor` user can use - some helper functions directly (see next section) - -Operator Support -~~~~~~~~~~~~~~~~ -There are two types of operator support, torch function and aten ops. For torch functions (e.g. ``torch.nn.functional.linear``), we’ll need to overwrite ``__torch_function__`` callback in the Tensor subclass, for aten ops (e.g. ``torch.ops.aten.mm``), we’ll need to overwrite ``__torch_dispatch__`` callback function. - -For a new dtype, we’d like people to define the following decorator:: - if your dtype class is inherited from `torchao.utils.TorchAoBaseTensor`, you can do: - - implements = my_dtype_tensor_cls.implements - -And we can implement the operator dispatch with the following:: - # Example for torch_function dispatch for torch.nn.functional.linear - def _quantized_linear_op(input_tensor, weight_tensor, bias): - if isinstance(input_tensor, MyDtypeTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, MyDtypeTensor): - weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - - @implements(torch.nn.functional.linear) - def _(*args, **kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - # using try/except here so that we can have a general fallback when input_tensor/weight_tensor - # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to - # make the branches easier to understand in `_quantized_linear_op` - try: - return _quantized_linear_op(input_tensor, weight_tensor, bias) - except NotImplementedError: - if isinstance(input_tensor, MyDtypeTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, MyDtypeTensor): - weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - # Example for aten op dispatch for aten.detach.default - @implements(aten.detach.default) - def _(func, *args, **kwargs): - # `return_and_correct_aliasing` should be used by wrapper tensor ``__torch_dispatch__`` subclasses that would like to - # work with torch.compile. It ensures that the subclass properly implements the aliasing behavior of every op, - # which is needed for correctness in AOTAutograd. - - # `_apply_fn_to_data` just applies the function to the tensor data in `args[0]`, `args[0]` is a tensor subclass - # of `my_dtype` - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - -What ops do we need to overwrite? This depends on the model we are trying to quantize, commonly overwritten ops are: -``__torch_function__``: ``torch.nn.functional.linear`` -``__torch_dispatch__``: ``torch.ops.aten.addmm.default``, ``torch.ops.aten.mm.default``, ``torch.ops.aten.detach.default``, ``torch.ops.aten.t.default`` - -You can also find the ops that can be overwritten in ``__torch_function__`` or ``__torch_dispatch__`` with the following code, and you can start with a model that you want to optimize, start with just overwriting the important ops like linear, and gradually expand the coverage until the test runs and you get the expected optimized generated code (see Optimized Operators section for more details):: - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(10, 10) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) + x - - from torch.overrides import TorchFunctionMode - class TorchFunctionLoggingMode(TorchFunctionMode): - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - print(f"TORCH_FUNC={str(func)}") - return func(*args, **kwargs) - - with TorchFunctionLoggingMode(): - m(*example_inputs) - - ## Example output - # TORCH_FUNC= - # TORCH_FUNC= - - - from torch.utils._python_dispatch import TorchDispatchMode - class TorchDispatchLoggingMode(TorchDispatchMode): - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - print(f"ATEN_FUNC={str(func)}") - return func(*args, **kwargs) - - with TorchDispatchLoggingMode(): - m(*example_inputs) - - ## Example output - # ATEN_FUNC=aten.t.default - # ATEN_FUNC=aten.addmm.default - # ATEN_FUNC=aten.add.Tensor - - # or a more polished logging for torch_dispatch (aten) ops: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py - -Alternatively, you can run a test example (e.g. use your quantized model with tensor parallelism, FSDP etc.) and discover the missing ops and add them until the test passes. - -We are still working on a table that talks about for each feature what are the operators that need to be supported. - Adding Efficient Kernels ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/index.rst b/docs/source/index.rst index 04a53ce454..f526c77939 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,3 +37,5 @@ for an overall introduction to the library and recent highlight and updates. :caption: Tutorials serialization + subclass_basic + subclass_advanced diff --git a/docs/source/subclass_advanced.rst b/docs/source/subclass_advanced.rst new file mode 100644 index 0000000000..f2df5a1cf0 --- /dev/null +++ b/docs/source/subclass_advanced.rst @@ -0,0 +1,4 @@ +Writing Your Own Quantized Tensor (advanced) +-------------------------------------------- + +Coming soon! diff --git a/docs/source/subclass_basic.rst b/docs/source/subclass_basic.rst new file mode 100644 index 0000000000..e007ea5bab --- /dev/null +++ b/docs/source/subclass_basic.rst @@ -0,0 +1,462 @@ +Writing Your Own Quantized Tensor +--------------------------------- + +Quantization in torchao is built on the foundation of tensor subclasses. +They are the main extension point for torchao to provide flexible +inference and training support using low precision computation, while +composing with important PyTorch features such as torch.compile, +autograd, and distributed primitives. + +In this tutorial, we will highlight the benefits of leveraging tensor +subclasses compared to module swaps, and walk through a simple example +of how to express quantization using this approach. + +What are Tensor Subclasses? +=========================== + +Tensor subclasses are simply classes that inherit from `torch.Tensor `__. +They allow users to interpose their custom computation logic between existing +ops in their models, such that functions in the top-level torch +namespace like torch.add will continue to work seamlessly. + +An obvious alternative to the tensor subclass approach is module swaps: +simply swap all nn.Linear modules in your model with your custom +Int8QuantizedLinear modules, for example. There are a few important +benefits of using tensor subclasses compared to this approach: + +1. **Finer-grained integration point.** Module swaps intercept + computation at the module level and so will not work for models that + rely on torch functions or variants of native modules (e.g. slightly + modified versions of nn.Linear). In contrast, since tensor subclasses + intercept computation at the function/op level, we will be able to + quantize the model as long as the same function/op is used. + +2. **Better composability.** Composing multiple features using module + swaps is clunky. For example, combining two existing + Int8QuantizedLinear and DistributedLinear modules would require users + to create another linear class that duplicates these functionalities. + Tensor subclasses bypass this problem by simply wrapping one subclass + in another. This can also offer performance benefits if the outer + tensor (e.g. `DTensor `__) + is aware that the inner tensor is quantized, and so can perform + expensive allgather operations using less network and memory + bandwidth. + +3. **Reusing PyTorch components.** It is natural to express quantization + using tensor subclasses since the quantized tensors are simply + torch.Tensors with different dtypes. The model structure does not + change (nn.Linears stay as nn.Linears), and so subsequent + optimization passes can also stay exactly the same as before. + +| +In the rest of the tutorial, we will walk through an example of how to +express quantization using both approaches. For further reading on +tensor subclasses, please refer to: + +- `Tensor subclass documentation `__ +- `Tensor subclass zoo `__ +- `Tensor subclass podcast by Edward Yang `__ + +Quantization with Module Swaps +============================== + +We begin with a simple example of how to implement int8 symmetric weight +only quantization using module swaps. All code can be found in this +`example script `__. +We will use the following function for quantizing float32 tensors into +int8 tensors: + +.. code:: py + + from typing import Tuple + import torch + + def int8_symmetric_quantize( + fp32_tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Symmetrically quantize the torch.float32 tensor into torch.int8. + Return a 2-tuple of (quantized value, scale). + + input: dimensions=[M, N], dtype=torch.float32 + output: dimensions=[M, N], dtype=torch.int8 + scale: dimensions=[M, 1], dtype=torch.float32 + """ + quant_min = -128 + quant_max = 127 + min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False) + max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = scale.view(fp32_tensor.shape[0], -1) + out = torch.round(fp32_tensor * (1.0 / scale)) + out = torch.clamp(out, quant_min, quant_max).to(torch.int8) + return out, scale + +Next, we will create a new QuantizedLinear module that calls this +function to dynamically quantize the weights: + +.. code:: py + + class QuantizedLinear(torch.nn.Linear): + """ + Linear module that performs dynamic and symmetric weight-only + int8 quantization. + """ + def forward(self, x: torch.Tensor) -> torch.Tensor: + w_int8, scale = int8_symmetric_quantize(self.weight) + return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t() + + @classmethod + def from_float(cls, mod: torch.nn.Linear): + new_linear = cls(mod.in_features, mod.out_features, mod.bias) + new_linear.weight = mod.weight + return new_linear + +Then, the only thing that’s left is to swap all `nn.Linear` modules in the +model with our new QuantizedLinear. Let’s use the following toy model +for demonstration purposes: + +.. code:: py + + import copy + + class ToyModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + float_model = ToyModel(64, 128, 32).cuda() + quantized_model = copy.deepcopy(float_model) + + # Swap torch.nn.Linear with QuantizedLinear + for name, child in quantized_model.named_children(): + if type(child) == torch.nn.Linear: + new_linear = QuantizedLinear.from_float(child) + setattr(quantized_model, name, new_linear) + +Verify that the model now uses our QuantizedLinear module. This model is +now ready to use! + +.. code:: py + + >>> print(float_model) + ToyModel( + (linear1): Linear(in_features=64, out_features=128, bias=False) + (linear2): Linear(in_features=128, out_features=32, bias=False) + ) + + >>> print(quantized_model) + ToyModel( + (linear1): QuantizedLinear(in_features=64, out_features=128, bias=False) + (linear2): QuantizedLinear(in_features=128, out_features=32, bias=False) + ) + +An important drawback of this simple approach is flexibility. Currently +this only works for native PyTorch modules, but what if the model has +slightly modified linear modules that, for example, support distributed +training? It also won’t work with models that directly call the functional +version of linear (`torch.nn.functional.linear`) instead. + +Further, suppose we want to compose this feature with distribution, +which is also implemented through module swaps. There is no clean way to +do this except to create yet another module that combines both features. +These limitations can be solved with tensor subclasses, which is a more +elegant way to interpose custom computation such as quantization in your +model. + +Quantization with Tensor Subclasses +=================================== + +Here we are going to re-implement the above quantization technique, +using a `__torch_dispatch__`-based tensor subclass. + +Tensor subclasses (which often utilize `__torch_dispatch__`) are a pretty +powerful/flexible extension point in pytorch. They serve two main +purposes as an extension point: + +1) Tensor subclasses allow you to override the **implementation** of + (almost) every PyTorch API, and are used quite a bit to implement + other PyTorch offerings +2) Tensor subclasses allow you to **couple** your tensor data with + additional metadata. A few examples + + 1) [distributed] metadata on how a tensor is sharded across ranks + (`DTensor `__, + `docs `__) + 2) [quantization] scale/zero_point metadata + (`AffineQuantizedTensor `__) + 3) [raggedness] metadata on ragged structure + (`NestedTensor `__, + `docs `__) + +Some other resources on tensor subclasses for those who are interested: + +1) \__torch_dispatch_\_ docs + (`link `__) +2) What (and why) is \__torch_dispatch_\_ + (`link `__) +3) Google collab that implements a FlopCounter and MemoryTracker using + \__torch_dispatch_\_ + (`link `__) + +With that out of the way, let’s start by defining our bare-bones tensor +subclass for symmetric quantization: + +.. code:: py + + class Int8SymmetricTensor(torch.Tensor): + """ + Our subclass represents a tensor that has been quantized to int8 + It will hold two inner tensors: + int_data: int8[M, N] + scale: fp32[M, 1] + """ + + @staticmethod + @torch._dynamo.disable + def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor): + return torch.Tensor._make_wrapper_subclass( + cls, + int_data.shape, + strides=int_data.stride(), + storage_offset=int_data.storage_offset(), + dtype=scale.dtype, + device=int_data.device, + ) + + @torch._dynamo.disable + def __init__(self, int_data: torch.Tensor, scale: torch.Tensor): + # inner data expected to be quantized already + assert int_data.dtype is torch.int8 + # we could do more work to support ndim > 2! + assert int_data.ndim == 2 + assert scale.ndim == 2 + self.int_data = int_data + self.scale = scale + + def __tensor_flatten__(self) -> Tuple[List[str], Any]: + """ + Returns a tuple of: + names of all inner tensor attributes (two in our case) + any other additional, non-tensor metadata. + + Needed for PT2 support. + """ + return ["int_data", "scale"], None + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None): + """ + __tensor_unflatten__ should effectively undo __tensor_flatten__. + + inputs: + a dict mapping names of inner tensor attributes back to the tensors + the constant metadata from __tensor_flatten__ + output: + a new instance of your subclass + + Needed for PT2 support. + """ + assert extra_metadata is None + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + return Int8SymmetricTensor(int_data, scale) + + def __repr__(self): + return f'Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})' + + @staticmethod + def from_float(float_tensor): + """ + Actually performs the symmetric quantization. + In our simple inference example we will quantize weights "ahead-of-time", + although later in a training example we can quantize/dequantize + during model execution, inside of our __torch_dispatch__ + + input: + float32 torch.Tensor + output: + Int8SymmetricTensor + """ + int8_tensor, scale = int8_symmetric_quantize(float_tensor) + return Int8SymmetricTensor(int8_tensor, scale) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + """ + Called for each ATen operator that our subclass is passed as an input to. + We need to define our own implementation for every operator here. + """ + if kwargs is None: + kwargs = {} + if func not in op_implementations_dict: + raise AssertionError(f'Int8SymmetricTensor does not yet support op: {str(func)}') + return op_implementations_dict[func](func, *args, **kwargs) + + + # Convenience function for registering our own implementation + # to every ATen operator in PyTorch + op_implementations_dict = {} + def register_op(ops: List[torch._ops.OpOverload]): + def impl_decorator(op_impl): + global op_implementations_dict + for op in ops: + op_implementations_dict[op] = op_impl + return op_impl + + return impl_decorator + +In the above code, we have done a few things: + +1) Defined a basic “wrapper” tensor subclass - it is effectively a + container object, that holds some inner data (in particular, two + tensors that correspond to our int8 data and scales) +2) Defined a `__torch_dispatch__` implementation, which will be called + for every ATen operator our model calls on any of our subclass inputs +3) (For PT2 support) Defined a `__tensor_flatten__`/`__tensor_unflatten__` + method. This is the largest of a few requirements we have in order for + our subclass to work with torch.compile (more on this later). It + effectively tells `torch.compile` how to “desugar” our subclass into + its inner components. +4) (For PT2 support) Added a `torch._dynamo.disable` decorator to both + constructor methods (`__new__` and `__init__`) (more on this later). + +Which operators should we implement? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PyTorch has a pretty large operator surface. Instead of trying to give +our new tensor subclass 100% coverage, let’s just focus on the ops we +need for our toy model above. + +Which operators are called in our model though, so we know what to +implement first? The brute force way is to repeatedly run the model +to see what ops error in your subclass. A more elegant way is to log +every operator that your model sees during execution. This can be +achieved through another `LoggingTensor` subclass as in `this example `__. + +Let's implement the necessary ops below: + +.. code:: py + + from torch.utils._python_dispatch import return_and_correct_aliasing + + @register_op([torch.ops.aten.mm.default]) + def int8_mm(func, x, weight): + assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!" + return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale + + @register_op([ + torch.ops.aten.detach.default, + torch.ops.aten.t.default, + ]) + def int8_view_ops(func, *args, **kwargs): + assert isinstance(args[0], Int8SymmetricTensor) + out_data = func(args[0].int_data, *args[1:], **kwargs) + out_scale = func(args[0].scale, *args[1:], **kwargs) + out = Int8SymmetricTensor(out_data, out_scale) + return return_and_correct_aliasing(func, args, kwargs, out) + +One thing you’ll notice quickly is: our model itself consists of a few +linear layers, but we see a few operations like `aten.t` and `aten.mm` +hitting our subclass. Some background: + +- We have a number of op decompositions that live in C++, that run + “above” tensor subclasses. `linear` is one such op (the decomp + lives `here `__) +- Decompositions can be good in the sense that they shrink the size of + the API that you as a subclass author have to implement. But they can + be painful if you would rather override the “higher level” operator + than the underlying operations in its decomposition. +- If you would prefer to override some operations (like Linear) at a + higher level, you can do so using `__torch_function__` + (`example `__). + It’s worth noting that if you want autograd support, then any + overrides you perform at the `__torch_function__` layer need to be + written in a way that is differentiable, while any overrides you + perform in `__torch_dispatch__` will be automatically differentiable. + +There are a few nuances in our implementations worth pointing out: + +1) You’ll notice that we no longer had to transpose our weight / scales + inside of our mm implementation. That’s because the transposition + “already happened” before we got to the `aten.mm` op. +2) Our `aten.mm` implementation does **not** return a tensor subclass + output. In that sense, the “propagation” of our quantized subclass + ends with matmuls. This maps to the fact that our weights are in low + precision, but we need to perform the matmuls themselves in high + precision. In general, subclass authors are free to choose for which + ops their subclasses do-or-do-not propagate. If you wanted every + function in your model to be quantized (including all pointwise and + reduction operations), you could write your subclass implementation + to quantize the output of every op and always return a subclass. +3) We were able to re-use the same implementation for 4 view operations. + In general, many ops might work with a pretty generic implementation: + unwrap any subclass inputs, run the underlying operator on the inner + tensor, and wrap the output back into a subclass. + + - Whether you can always re-use an implementation, though, depends + on what you are trying to do. For example, we implemented + `transpose(dim0, dim1)` on our subclass by calling the same + transpose on our inner data and inner scale tensor. This wouldn’t + work if our scale and data tensors had a different number of + dimensions, so transposition in that case would require a custom + implementation. + + +Comparing the Outputs +===================== + +And with all of that out of the way, let’s run our model with both +versions of quantization and confirm that they give the same output! + +.. code:: py + + float_model = ToyModel(64, 128, 32).cuda() + quantized_model_module_swap = copy.deepcopy(float_model) + quantized_model_subclass = copy.deepcopy(float_model) + + # Swap torch.nn.Linear with QuantizedLinear + for name, child in quantized_model_module_swap.named_children(): + if type(child) == torch.nn.Linear: + new_linear = QuantizedLinear.from_float(child) + setattr(quantized_model_module_swap, name, new_linear) + + # Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses + for name, child in quantized_model_subclass.named_children(): + if type(child) == torch.nn.Linear: + subclass_param = Int8SymmetricTensor.from_float(child.weight) + child.weight = torch.nn.Parameter(subclass_param, requires_grad=True) + + with torch.no_grad(): + x = torch.randn(64, 64, 64, device='cuda') + out_module_swap = quantized_model_module_swap(x) + out = quantized_model_subclass(x) + print(torch.allclose(out, out_module_swap)) # prints True + + # We can also use torch.compile to fuse some of our quantized logic + out_compiled = torch.compile(quantized_model_subclass)(x) + print(torch.allclose(out, out_compiled)) # prints True + + +Next Steps +========== + +In this tutorial, we demonstrated how to build a simple quantized tensor +subclass. This is part one of two tutorials in this series. The +`next post `__ will discuss how to add more advanced +features to your tensor subclass, such as making it trainable, composing +with DTensors, and adding tensor parallelism support. For a more detailed +example of how `AffineQuantizedTensor` in torchao was built using tensor +subclasses, also check out `this example `__. + +If you have any questions while implementing your subclass, feel free to +file an issue `here `__. diff --git a/tutorials/examples/logging_subclass.py b/tutorials/examples/logging_subclass.py new file mode 100644 index 0000000000..ded50c56d6 --- /dev/null +++ b/tutorials/examples/logging_subclass.py @@ -0,0 +1,66 @@ +import torch +import torch.utils._pytree as pytree + + +class LoggingTensor(torch.Tensor): + @staticmethod + def __new__(cls, a): + return torch.Tensor._make_wrapper_subclass( + cls, + a.shape, + strides=a.stride(), + storage_offset=a.storage_offset(), + dtype=a.dtype, + device=a.device, + ) + + def __init__(self, a): + self.a = a + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + print("func: " + str(func)) + # Our logging subclass trivially implements *every* pytorch op. + # It does so by: + # - unwrapping any LoggingTensor arguments + # - calling the underlying function on the inner tensors + # - wrapping any tensor outputs into LoggingTensors + args_a = pytree.tree_map_only(LoggingTensor, lambda x: x.a, args) + kwargs_a = pytree.tree_map_only(LoggingTensor, lambda x: x.a, kwargs) + out_a = func(*args_a, **kwargs_a) + out_a_flat, spec = pytree.tree_flatten(out_a) + out_flat = [ + cls(o_a) if isinstance(o_a, torch.Tensor) else o_a for o_a in out_a_flat + ] + return pytree.tree_unflatten(out_flat, spec) + + +class ToyModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +if __name__ == "__main__": + # Set up toy model + float_model = ToyModel(64, 128, 32).cuda() + + # Replace any linear layer weights with our LoggingTensor + for name, child in float_model.named_children(): + if type(child) == torch.nn.Linear: + child.weight = torch.nn.Parameter( + LoggingTensor(child.weight), requires_grad=True + ) + + # run the model + with torch.no_grad(): + x = torch.randn(64, 64, 64, device="cuda") + _ = float_model(x) diff --git a/tutorials/examples/quantized_module_swap.py b/tutorials/examples/quantized_module_swap.py new file mode 100644 index 0000000000..07281a5bca --- /dev/null +++ b/tutorials/examples/quantized_module_swap.py @@ -0,0 +1,72 @@ +from typing import Tuple + +import torch + + +def int8_symmetric_quantize( + fp32_tensor: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Symmetrically quantize the torch.float32 tensor into torch.int8. + Return a 2-tuple of (quantized value, scale). + + input: dimensions=[M, N], dtype=torch.float32 + output: dimensions=[M, N], dtype=torch.int8 + scale: dimensions=[M, 1], dtype=torch.float32 + """ + quant_min = -128 + quant_max = 127 + min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False) + max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = scale.view(fp32_tensor.shape[0], -1) + out = torch.round(fp32_tensor * (1.0 / scale)) + out = torch.clamp(out, quant_min, quant_max).to(torch.int8) + return out, scale + + +class QuantizedLinear(torch.nn.Linear): + """ + Linear module that performs dynamic and symmetric weight-only + int8 quantization. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + w_int8, scale = int8_symmetric_quantize(self.weight) + return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t() + + @classmethod + def from_float(cls, mod: torch.nn.Linear): + new_linear = cls(mod.in_features, mod.out_features, mod.bias) + new_linear.weight = mod.weight + return new_linear + + +class ToyModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +if __name__ == "__main__": + # Set up toy model + model = ToyModel(64, 128, 32).cuda() + example_inputs = torch.randn((1, 64), dtype=torch.float32, device="cuda") + + # Swap torch.nn.Linear with QuantizedLinear + for name, child in model.named_children(): + if type(child) == torch.nn.Linear: + new_linear = QuantizedLinear.from_float(child) + setattr(model, name, new_linear) + + print("quantized model: ", model) + print("output: ", model(example_inputs)) diff --git a/tutorials/examples/quantized_subclass.py b/tutorials/examples/quantized_subclass.py new file mode 100644 index 0000000000..e256068294 --- /dev/null +++ b/tutorials/examples/quantized_subclass.py @@ -0,0 +1,183 @@ +import copy +from typing import Any, List, Tuple + +import torch + + +def int8_symmetric_quantize( + fp32_tensor: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Symmetrically quantize the torch.float32 tensor into torch.int8. + Return a 2-tuple of (quantized value, scale). + """ + quant_min = -128 + quant_max = 127 + min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False) + max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False) + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + scale = scale.view(fp32_tensor.shape[0], -1) + out = torch.round(fp32_tensor * (1.0 / scale)) + out = torch.clamp(out, quant_min, quant_max).to(torch.int8) + return out, scale + + +# Our subclass represents a tensor that has been quantized to int8 +# It will hold two inner tensors: +# - int_data: int8[M, N] +# - scale: fp32[M, 1] +class Int8SymmetricTensor(torch.Tensor): + @staticmethod + @torch._dynamo.disable + def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor): + return torch.Tensor._make_wrapper_subclass( + cls, + int_data.shape, + strides=int_data.stride(), + storage_offset=int_data.storage_offset(), + dtype=scale.dtype, + device=int_data.device, + ) + + @torch._dynamo.disable + def __init__(self, int_data: torch.Tensor, scale: torch.Tensor): + # inner data expected to be quantized already + assert int_data.dtype is torch.int8 + # we could do more work to support ndim > 2! + assert int_data.ndim == 2 + assert scale.ndim == 2 + self.int_data = int_data + self.scale = scale + + # __tensor_flatten__ returns a tuple of: + # - names of all inner tensor attributes (two in our case) + # - any other additional, non-tensor metadata. + def __tensor_flatten__(self) -> Tuple[List[str], Any]: + return ["int_data", "scale"], None + + # __tensor_unflatten__ should effectively undo __tensor_flatten__. + # inputs: + # - a dict mapping names of inner tensor attributes back to the tensors + # - the constant metadata from __tensor_flatten__ + # output: + # - a new instance of your subclass + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None + ): + assert extra_metadata is None + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + return Int8SymmetricTensor(int_data, scale) + + def __repr__(self): + return f"Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})" + + # Actually performs the symmetric quantization. + # In our simple inference example we will quantize weights "ahead-of-time", + # although later in a training example we can quantize/dequantize + # during model execution, inside of our __torch_dispatch__ + # input: + # - float32 torch.Tensor + # output: + # - Int8SymmetricTensor + @staticmethod + def from_float(float_tensor): + int8_tensor, scale = int8_symmetric_quantize(float_tensor) + return Int8SymmetricTensor(int8_tensor, scale) + + # __torch_dispatch__ gets called for ATen operator + # that our subclass is passed as an input to. + # We need to define our own implementation for every operator here. + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + if func not in op_implementations_dict: + raise AssertionError( + f"Int8SymmetricTensor does not yet support op: {str(func)}" + ) + return op_implementations_dict[func](func, *args, **kwargs) + + +# Convenience function for registering our own implementation +# to every ATen operator in PyTorch +op_implementations_dict = {} + + +def register_op(ops: List[torch._ops.OpOverload]): + def impl_decorator(op_impl): + global op_implementations_dict + for op in ops: + op_implementations_dict[op] = op_impl + return op_impl + + return impl_decorator + + +from torch.utils._python_dispatch import return_and_correct_aliasing + + +# matmul impl +@register_op([torch.ops.aten.mm.default]) +def int8_mm(func, x, weight): + assert isinstance( + weight, Int8SymmetricTensor + ), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!" + return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale + + +# implementation of most view operations +@register_op( + [ + torch.ops.aten.detach.default, + torch.ops.aten.t.default, + torch.ops.aten.view.default, + torch.ops.aten._unsafe_view.default, + ] +) +def int8_view_ops(func, *args, **kwargs): + assert isinstance(args[0], Int8SymmetricTensor) + out_data = func(args[0].int_data, *args[1:], **kwargs) + out_scale = func(args[0].scale, *args[1:], **kwargs) + out = Int8SymmetricTensor(out_data, out_scale) + # "return_and_correct_aliasing" here is needed for torch.compile support. + # It effectively tells the compiler that the output of this view op aliases its input. + # At some point, we're hoping to infer this automatically and kill this extra API! + return return_and_correct_aliasing(func, args, kwargs, out) + + +class ToyModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +if __name__ == "__main__": + # Set up toy model + float_model = ToyModel(64, 128, 32).cuda() + quantized_model_subclass = copy.deepcopy(float_model) + + # Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses + for name, child in quantized_model_subclass.named_children(): + if type(child) == torch.nn.Linear: + subclass_param = Int8SymmetricTensor.from_float(child.weight) + child.weight = torch.nn.Parameter(subclass_param, requires_grad=True) + + with torch.no_grad(): + x = torch.randn(64, 64, 64, device="cuda") + out = quantized_model_subclass(x) + + # We can also use torch.compile to fuse some of our quantized logic + # run with TORCH_LOGS="output_code" to see the generated inductor code + out_compiled = torch.compile(quantized_model_subclass)(x) + print(torch.allclose(out, out_compiled)) From 11440c2a7518977f58c25a0a47755dd692178bf3 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 24 Jan 2025 15:57:32 -0800 Subject: [PATCH 029/115] mx cleanup [1/x]: unbreak mx_formats tests (#1569) Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 8b6370a5cb..ead45cb8f4 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -26,6 +26,16 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) +# source: https://stackoverflow.com/a/22638709 +@pytest.fixture(autouse=True) +def run_around_tests(): + # 1. before test - set up (currently do nothing) + # 2. run test + yield + # 3. after test - teardown + torch._dynamo.reset() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("bias", [True, False]) From 6b472e5b62d11f2871dd3a65356b4bb1e9936861 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 24 Jan 2025 15:58:21 -0800 Subject: [PATCH 030/115] mx cleanup [2/x]: refactor mx gemm (#1593) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 31 ++++-- test/prototype/mx_formats/test_mx_tensor.py | 3 +- torchao/prototype/mx_formats/mx_linear.py | 101 +++++++++++++++----- torchao/prototype/mx_formats/mx_ops.py | 15 +-- torchao/prototype/mx_formats/mx_tensor.py | 7 ++ 5 files changed, 109 insertions(+), 48 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index ead45cb8f4..d280e38c36 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -39,7 +39,7 @@ def run_around_tests(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("input_shape", [(2, 4), (1, 2, 4), (1, 1, 2, 4)]) +@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)]) def test_linear_eager(elem_dtype, bias, input_shape): """ Smoke test for training linear module with mx weight @@ -48,7 +48,7 @@ def test_linear_eager(elem_dtype, bias, input_shape): grad_shape[-1] = 6 m = nn.Sequential( - nn.Linear(4, 6, bias=bias, device="cuda"), + nn.Linear(8, 6, bias=bias, device="cuda"), ) m_mx = copy.deepcopy(m) block_size = 2 @@ -71,7 +71,7 @@ def test_linear_eager(elem_dtype, bias, input_shape): if elem_dtype is torch.float8_e4m3fn: assert y_sqnr >= 18.0 assert w_g_sqnr >= 18.0 - assert x_g_sqnr >= 14.0 + assert x_g_sqnr >= 12.0 else: assert y_sqnr >= 8.0 assert w_g_sqnr >= 10.0 @@ -101,28 +101,41 @@ def test_activation_checkpointing(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("bias", [False, True]) -def test_linear_compile(elem_dtype, bias): +# TODO(future PR): figure out why torch.compile does not match eager when +# autocast is on +@pytest.mark.parametrize( + "use_autocast", + [ + False, + ], +) +def test_linear_compile(elem_dtype, bias, use_autocast): """ Verify that compile does not change numerics of MX linear fw + bw """ if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): if not is_sm_at_least_89(): pytest.skip("CUDA capability >= 8.9 required for float8 in triton") - input_shape = (2, 4) - grad_shape = (2, 6) + M, K, N = 4, 8, 6 + input_shape = (M, K) + grad_shape = (M, N) m_mx = nn.Sequential( - nn.Linear(4, 6, bias=bias, device="cuda"), + nn.Linear(K, N, bias=bias, device="cuda"), ) block_size = 2 swap_linear_with_mx_linear(m_mx, elem_dtype, block_size) m_mx_c = copy.deepcopy(m_mx) - m_mx_c = torch.compile(m_mx_c, fullgraph=True) + m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() x = copy.deepcopy(x_ref) g = torch.randn(*grad_shape, device="cuda") - with torch.autocast("cuda", dtype=torch.bfloat16): + if use_autocast: + with torch.autocast("cuda", dtype=torch.bfloat16): + y_ref = m_mx(x_ref) + y = m_mx_c(x) + else: y_ref = m_mx(x_ref) y = m_mx_c(x) torch.testing.assert_close(y_ref, y, atol=0, rtol=0) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 02824f60d3..ae87ee021e 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -167,8 +167,9 @@ def test_transpose(elem_dtype, fp4_triton): if elem_dtype != DTYPE_FP4 and fp4_triton: pytest.skip("unsupported configuration") - tensor_hp = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16) + M, K = 128, 256 block_size = 32 + tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) config.use_fp4_custom_triton_dequant_kernel = fp4_triton tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t() diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index c429eb57d4..b69441e018 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -5,42 +5,81 @@ # LICENSE file in the root directory of this source tree. """ -Defines the UX for converting a model to use mx weights - -For now, this is a module swap for speed of iteration. - -Eventually we plan to move this to a tensor subclass weight wrapper for -inference, and to a tensor subclass weight wrapper + module hooks for training. +Defines the prototype UX for converting a model to use mx weights """ +from typing import Any + import torch import torch.nn.functional as F -from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx +from torchao.prototype.mx_formats.mx_tensor import MXTensor @torch._dynamo.allow_in_graph -class NoopFwToMXBw(torch.autograd.Function): - """ - Forward: no-op - Backward: cast grad to MX - """ +class mx_mm(torch.autograd.Function): + # There are three gemms in a forward + backward of a Linear layer: + # + # 1. input @ weight_t = output (forward pass) + # 2. grad_output @ weight = grad_input (backward pass) + # 3. input_t @ grad_output = grad_weight (backward pass) @staticmethod - def forward(ctx, x, elem_dtype, block_size): + def forward( + ctx, + input_hp: torch.Tensor, + weight_hp: torch.Tensor, + elem_dtype: Any, + block_size: int, + ): + ctx.save_for_backward(input_hp, weight_hp) ctx.elem_dtype = elem_dtype ctx.block_size = block_size - return x + + # input @ weight_t = output + input_orig_shape = input_hp.shape + input_hp_r = input_hp.reshape(-1, input_orig_shape[-1]) + + input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size) + weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size) + output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t()) + output = output.reshape(*input_orig_shape[:-1], output.shape[-1]) + + return output @staticmethod - def backward(ctx, g): - scale, data = to_mx(g, ctx.elem_dtype, ctx.block_size) - return ( - MXTensor(scale, data, ctx.elem_dtype, ctx.block_size, g.dtype), - None, - None, + def backward(ctx, grad_output_hp: torch.Tensor): + input_hp, weight_hp = ctx.saved_tensors + weight_hp_t_c = weight_hp.t().contiguous() + elem_dtype = ctx.elem_dtype + block_size = ctx.block_size + + grad_output_orig_shape = grad_output_hp.shape + grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1]) + + input_hp_orig_shape = input_hp.shape + input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1]) + + # grad_output @ weight = grad_input + grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size) + weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size) + grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) + grad_input = grad_input.reshape( + *grad_output_orig_shape[:-1], grad_input.shape[-1] ) + # input_t @ grad_output = grad_weight + grad_output_mx_dim1 = MXTensor.to_mx( + grad_output_hp_r.t().contiguous(), elem_dtype, block_size + ) + input_t_mx_dim0_tmp = MXTensor.to_mx( + input_hp_r.t().contiguous(), elem_dtype, block_size + ) + input_t_mx_dim0 = input_t_mx_dim0_tmp.t() + grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) + + return grad_input, grad_weight, None, None + class MXLinear(torch.nn.Linear): """ @@ -59,16 +98,26 @@ def from_float(cls, mod, elem_dtype, block_size): return mod def forward(self, x): - x_mx = MXTensor.to_mx(x, self.elem_dtype, self.block_size) - w_mx = MXTensor.to_mx(self.weight, self.elem_dtype, self.block_size) - y = F.linear(x_mx, w_mx, self.bias) - y = NoopFwToMXBw.apply(y, self.elem_dtype, self.block_size) + if torch.is_autocast_enabled(): + # special case autocast + autocast_dtype = torch.get_autocast_dtype("cuda") + x = x.to(autocast_dtype) + w = self.weight.to(autocast_dtype) + else: + w = self.weight + + y = mx_mm.apply(x, w, self.elem_dtype, self.block_size) + if self.bias is not None: + y = y + self.bias return y class MXInferenceLinear(torch.nn.Linear): """ Inference version of MXLinear, with the weight pre-quantized to MX. + + Note: this is weight-only quantization, with the gemm being executed + in high precision. """ @classmethod @@ -84,8 +133,8 @@ def from_float(cls, mod, elem_dtype, block_size): # TODO(future PR): set to new_mod.weight directly, will need to work # through some errors new_mod.weight_mx = MXTensor.to_mx( - mod.weight.t().contiguous(), elem_dtype, block_size=block_size - ).t() + mod.weight, elem_dtype, block_size=block_size + ) new_mod.bias = mod.bias new_mod.elem_dtype = elem_dtype return new_mod diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 7a404b89a8..57fb0d54b4 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -65,22 +65,13 @@ def mx_mm(aten_op, args, kwargs=None): assert isinstance(a, MXTensor) and isinstance(b, MXTensor) a_hp = a.to_dtype(a._orig_dtype) b_hp = b.to_dtype(b._orig_dtype) + # assert memory layout we expect to be required in hardware + assert a_hp.is_contiguous() + assert b_hp.t().is_contiguous() res = aten_op(a_hp, b_hp) return res -@implements([aten.addmm.default]) -def mx_addmm(aten_op, args, kwargs=None): - a = args[0] - b = args[1] - c = args[2] - assert isinstance(b, MXTensor) and isinstance(c, MXTensor) - b_hp = b.to_dtype(b._orig_dtype) - c_hp = c.to_dtype(c._orig_dtype) - res = aten_op(a, b_hp, c_hp) - return res - - @implements([aten.t.default]) def mx_t(aten_op, args, kwargs=None): # For now, only transpose(input, 0, 1) is supported. diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 2e67f5a4ac..8eeeaf8bfd 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -314,6 +314,10 @@ def __new__( new_size = data_bits.size() if elem_dtype == DTYPE_FP4: # set the tensor size to what it would be without 2x4 packing + # Note: `is_contiguous` is going to return True for a tensor of size + # (M, 1) regardless or the order of dims, so this logic is currently + # broken for tensors of size (M, 1) or (1, M). Leaving broken until + # a time when fixing this becomes important. new_size = tensor_size_fp4x2_to_hp( new_size, data_bits.is_contiguous(), @@ -321,6 +325,9 @@ def __new__( self = torch.Tensor._make_wrapper_subclass( cls, new_size, + strides=data_bits.stride(), + storage_offset=data_bits.storage_offset(), + layout=data_bits.layout, dtype=orig_dtype, device=data_bits.device, ) From 47f96f12a4ffa9468f395c667269ca0fa8eef06d Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 24 Jan 2025 17:05:54 -0800 Subject: [PATCH 031/115] add separate quantization primitives for float8 (#1597) --- test/quantization/test_quant_primitives.py | 70 ++++++++++++++++++++++ torchao/quantization/quant_primitives.py | 67 +++++++++++++++++++++ 2 files changed, 137 insertions(+) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 102e76cb1a..77616c1c6a 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -9,16 +9,21 @@ import unittest import torch +from parameterized import parameterized from torchao.dtypes.utils import is_device +from torchao.float8.float8_utils import EPS as float8_eps from torchao.quantization.quant_primitives import ( MappingType, ZeroPointDomain, choose_qparams_affine, + choose_qparams_affine_float8, dequantize_affine, + dequantize_affine_float8, fake_quantize_affine, fake_quantize_affine_cachemask, quantize_affine, + quantize_affine_float8, ) # TODO: remove test for utils? @@ -838,6 +843,71 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + @parameterized.expand( + [ + ( + torch.float32, + torch.float8_e4m3fn, + ), + ( + torch.float32, + torch.float8_e5m2, + ), + ( + torch.bfloat16, + torch.float8_e4m3fn, + ), + ( + torch.bfloat16, + torch.float8_e5m2, + ), + ] + ) + def test_float8_quant_primitives(self, hp_dtype, float8_dtype): + input = torch.randn(10, 10) + + # float8 quantization primitives + scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype) + quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype) + dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype) + + # reference implementation using generic primitives + expected_scale, _ = choose_qparams_affine( + input, + MappingType.SYMMETRIC, + input.shape, + float8_dtype, + eps=float8_eps, # use same EPS as float8 training + scale_dtype=torch.float32, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + ) + expected_quantized = quantize_affine( + input, + input.shape, + scale, + output_dtype=float8_dtype, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + zero_point=None, + zero_point_domain=None, + ) + expected_dequantized = dequantize_affine( + expected_quantized, + input.shape, + scale, + input_dtype=float8_dtype, + output_dtype=hp_dtype, + quant_min=torch.finfo(float8_dtype).min, + quant_max=torch.finfo(float8_dtype).max, + zero_point=None, + zero_point_domain=None, + ) + + self.assertTrue(torch.equal(expected_scale, scale)) + torch.testing.assert_close(expected_quantized, quantized) + torch.testing.assert_close(expected_dequantized, dequantized) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index e587d4bc2b..8b0ce28434 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -39,6 +39,9 @@ "MappingType", "ZeroPointDomain", "TorchAODType", + "choose_qparams_affine_float8", + "quantize_affine_float8", + "dequantize_affine_float8", ] @@ -1300,3 +1303,67 @@ def dequantize_affine_floatx( tensor = tensor * scale.float().view(-1, 1) tensor = tensor.to(dtype=output_dtype) return tensor + + +def choose_qparams_affine_float8( + tensor: torch.Tensor, + float8_dtype: torch.dtype = torch.float8_e4m3fn, +) -> torch.Tensor: + """ + Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity. + + Args: + tensor (torch.Tensor): Input tensor to be quantized. + float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). + """ + # only tensorwise scaling is supported for now: + quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max + min_val_neg = torch.min(tensor) + max_val_pos = torch.max(tensor) + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + return scale.to(dtype=torch.float32) + + +def quantize_affine_float8( + tensor: torch.Tensor, + scale: torch.Tensor, + float8_dtype: torch.dtype = torch.float8_e4m3fn, +) -> torch.Tensor: + """ + Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. + + Args: + tensor (torch.Tensor): Input tensor to be quantized. + scale (torch.Tensor): Scaling factor for the quantization. + float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). + """ + # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically + # upcasted to `float32` to multiply with the scale, since scale is a fp32 tensor in float8 quantization. + # In order to match numerics between eager and compile, we upcast manually here. + tensor_scaled = tensor.to(torch.float32) / scale + max_value = torch.finfo(float8_dtype).max + tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) + fp8_tensor = tensor_clamped.to(float8_dtype) + return fp8_tensor + + +def dequantize_affine_float8( + tensor: torch.Tensor, + scale: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Dequantizes the float8 tensor to high precision tensor. + + Args: + tensor (torch.Tensor): Input float8 tensor to be dequantized. + scale (torch.Tensor): Scaling factor for the dequantization. + output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32). + """ + # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically + # upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization. + # In order to match numerics between eager and compile, we upcast manually here. + fp8_tensor = tensor.to(torch.float32) + hp_tensor = fp8_tensor * scale + return hp_tensor.to(output_dtype) From 09dd63677a071d88ffbf064f4b79130853768cef Mon Sep 17 00:00:00 2001 From: "Jane (Yuan) Xu" <31798555+janeyx99@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:31:24 -0500 Subject: [PATCH 032/115] Prepare for -DPy_LIMITED_API flag in pytorch #145764 (#1627) * Prepare for enforcement of -DPy_LIMITED_API flag pytorch #145764 * Add the flag now to not regress * format --- setup.py | 16 ++++++++++------ torchao/csrc/cuda/fp6_llm/fp6_linear.cu | 4 ++-- .../s8s4_linear_cutlass/s8s4_linear_cutlass.cu | 2 +- .../tensor_core_tiled_layout.cu | 2 +- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index b657fa8df7..8628dc7ef4 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,8 @@ current_date = datetime.now().strftime("%Y%m%d") +PY3_9_HEXCODE = "0x03090000" + def get_git_commit_id(): try: @@ -212,24 +214,26 @@ def get_extensions(): extra_link_args = [] extra_compile_args = { + "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"], "nvcc": [ "-O3" if not debug_mode else "-O0", "-t=0", - ] + ], } if not IS_WINDOWS: - extra_compile_args["cxx"] = [ - "-O3" if not debug_mode else "-O0", - "-fdiagnostics-color=always", - ] + extra_compile_args["cxx"].extend( + ["-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always"] + ) if debug_mode: extra_compile_args["cxx"].append("-g") extra_compile_args["nvcc"].append("-g") extra_link_args.extend(["-O0", "-g"]) else: - extra_compile_args["cxx"] = ["/O2" if not debug_mode else "/Od", "/permissive-"] + extra_compile_args["cxx"].extend( + ["/O2" if not debug_mode else "/Od", "/permissive-"] + ) if debug_mode: extra_compile_args["cxx"].append("/ZI") diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 6141dc3d74..cc601da34b 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -25,9 +25,9 @@ #include #include -#include #include #include +#include #include @@ -261,4 +261,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("torchao::quant_llm_linear", &fp_eXmY_linear_forward_cuda); } -} // namespace torchao \ No newline at end of file +} // namespace torchao diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu index 411343f0da..6253f8d5f7 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu @@ -1,4 +1,4 @@ -#include +#include #include #include diff --git a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu index d3ddd66fe6..ea0f24c202 100644 --- a/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu +++ b/torchao/csrc/cuda/tensor_core_tiled_layout/tensor_core_tiled_layout.cu @@ -5,7 +5,7 @@ #include #include #include -#include +#include template constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { From 13bd59e1eada667d8bc616eaa8fdfeb882b740a3 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Mon, 27 Jan 2025 14:43:52 -0800 Subject: [PATCH 033/115] Update docs to refer to version.html (#1631) --- docs/source/_templates/layout.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index f1d3173de2..5f5bf020a5 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -2,7 +2,7 @@ {% block sidebartitle %} {% include "searchbox.html" %} {% endblock %} From e151d6a5288177a1a635c71fecd145654745af4c Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:14:04 -0500 Subject: [PATCH 034/115] notify when CI job fails (#1547) * test notify build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml * final commit * Update build_wheels_linux.yml * Update build_wheels_linux.yml * Update build_wheels_linux.yml --- .github/workflows/build_wheels_linux.yml | 35 ++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/.github/workflows/build_wheels_linux.yml b/.github/workflows/build_wheels_linux.yml index 3c37e0e1e0..8b966059f3 100644 --- a/.github/workflows/build_wheels_linux.yml +++ b/.github/workflows/build_wheels_linux.yml @@ -56,3 +56,38 @@ jobs: upload-to-pypi: cu121 secrets: PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + notify: + runs-on: ubuntu-latest + name: Email notification + needs: [generate-matrix, build] + if: failure() && github.event_name == 'schedule' + steps: + - uses: dawidd6/action-send-mail@v4 + with: + server_address: smtp.gmail.com + server_port: 465 + username: torchao.notify + password: ${{ secrets.TORCHAO_NOTIFY_PASSWORD }} + from: torchao.notify@gmail.com + to: ${{ secrets.TORCHAO_NOTIFY_RECIPIENT }} + subject: breakbutterflyScheduled Build Failure for TorchAO + body: | + Build Failure Notification for TorchAO + + A failure occurred in the Build Linux Wheels workflow. + + Run Details: + - Workflow: ${{ github.workflow }} + - Run Type: ${{ github.event_name }} + - Repository: ${{ github.repository }} + - Branch/PR: ${{ github.ref }} + - Commit: ${{ github.sha }} + + You can view the full run details here: + ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + + Error Information: + ${{ needs.generate-matrix.result == 'failure' && 'Matrix generation failed' || '' }} + ${{ needs.build.result == 'failure' && 'Build job failed' || '' }} + + This is an automated notification. Please check the GitHub Actions page for more details about the failure. From abd41e5f77cc5ab018094fdf3f8279111d2de320 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 28 Jan 2025 13:34:58 -0800 Subject: [PATCH 035/115] Add torchao/experimental CI test (#1586) * add torchao/experimental CI test * up * up * up * up * up * up --- .../workflows/torchao_experimental_test.yml | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .github/workflows/torchao_experimental_test.yml diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml new file mode 100644 index 0000000000..c1419bccc6 --- /dev/null +++ b/.github/workflows/torchao_experimental_test.yml @@ -0,0 +1,42 @@ +name: Run TorchAO Experimental Tests + +on: + push: + branches: + - main + - 'gh/**' + pull_request: + branches: + - main + - 'gh/**' + +jobs: + test: + strategy: + matrix: + runner: [macos-14] + runs-on: ${{matrix.runner}} + defaults: + run: + shell: bash -el {0} + steps: + - name: Checkout repo + uses: actions/checkout@v3 + with: + submodules: true + - name: Setup environment + uses: conda-incubator/setup-miniconda@v3 + with: + python-version: "3.10" + miniconda-version: "latest" + activate-environment: venv + - name: Install requirements + run: | + conda activate venv + pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104" + pip install numpy + USE_CPP=1 pip install . + - name: Run tests + run: | + conda activate venv + python torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py From 7b0d2ce50baaa2a137eb9d438a076544c43096a3 Mon Sep 17 00:00:00 2001 From: sanchitintel Date: Wed, 29 Jan 2025 00:13:58 -0800 Subject: [PATCH 036/115] Consolidate `ZeroPointDomain.NONE` & `None` zero point domains (#1556) * Fix ZeroPointDomain.NONE support & make it default for da8w8 weights * Fix bug & apply review recommendations * Throw exceptions when None zero_point_domain is used * Use ZeroPointDomain.NONE for weight in int8_dynamic_activation_int8_weight * Rebase with the latest main branch * Fix typo --- test/integration/test_integration.py | 47 ++++++++++-- test/quantization/test_observer.py | 17 +++-- test/quantization/test_quant_primitives.py | 53 ++++++++++++- torchao/dtypes/affine_quantized_tensor.py | 20 ++--- torchao/dtypes/uintx/marlin_qqq_tensor.py | 4 +- torchao/quantization/observer.py | 5 +- .../qat/affine_fake_quantized_tensor.py | 5 ++ torchao/quantization/qat/api.py | 2 + torchao/quantization/quant_api.py | 8 +- torchao/quantization/quant_primitives.py | 74 +++++++++++-------- 10 files changed, 171 insertions(+), 64 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index c926cee060..56bcaf17df 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -10,6 +10,7 @@ import logging import os import unittest +from functools import partial import torch import torch.nn as nn @@ -48,6 +49,7 @@ quantize_, ) from torchao.quantization.quant_primitives import ( + MappingType, dequantize_affine, ) from torchao.quantization.smoothquant import ( @@ -102,6 +104,8 @@ COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16] +ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC] + COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy() @@ -121,9 +125,18 @@ def _int8wo_groupwise_api(mod): quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False) -def _int8da_int8w_api(mod): +def _int8da_int8w_api( + mod, + act_mapping_type=MappingType.SYMMETRIC, +): if TORCH_VERSION_AT_LEAST_2_4: - quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False) + quantize_( + mod, + int8_dynamic_activation_int8_weight( + act_mapping_type=act_mapping_type, + ), + set_inductor_config=False, + ) if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(mod) else: @@ -962,10 +975,11 @@ def _test_lin_weight_subclass_api_impl( mod[0].weight.tensor_impl.get_plain() test = mod(x) + self.assertGreater( SQNR(ref_f, test), min_sqnr, - f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}", + f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}", ) mod_qc = torch.compile(mod, mode="max-autotune") @@ -973,14 +987,31 @@ def _test_lin_weight_subclass_api_impl( self.assertGreater( SQNR(ref_f, test_comp), min_sqnr, - f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}", + f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}", ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_int8_dynamic_quant_subclass_api(self, device, dtype): - self._test_lin_weight_subclass_api_impl( - _int8da_int8w_api, device, 35, test_dtype=dtype + @parameterized.expand( + list( + itertools.product( + COMMON_DEVICES, + COMMON_DTYPES, + ACT_MAPPING_TYPES, + ) + ) + ) + def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping): + if ( + not TORCH_VERSION_AT_LEAST_2_5 + and dtype in (torch.float16, torch.bfloat16) + and act_mapping is MappingType.ASYMMETRIC + and device == "cpu" + ): + self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5") + api = partial( + _int8da_int8w_api, + act_mapping_type=act_mapping, ) + self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype) @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(is_fbcode(), "broken in fbcode") diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py index 0526ee01b2..4567f3baef 100644 --- a/test/quantization/test_observer.py +++ b/test/quantization/test_observer.py @@ -21,6 +21,7 @@ ) from torchao.quantization.quant_primitives import ( MappingType, + ZeroPointDomain, ) @@ -74,7 +75,7 @@ def test_block_size_calc_success(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -93,7 +94,7 @@ def test_block_size_calc_success(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) for example_input in example_inputs: obs(example_input) @@ -108,7 +109,7 @@ def test_block_size_row_errors(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -127,7 +128,7 @@ def test_block_size_row_errors(self): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) example_inputs = [ torch.randn(10, 2048), @@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) if observe_weight: weight_observer = AffineQuantizedMinMaxObserver( @@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) else: weight_observer = None @@ -199,7 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): input_scale.item(), max_val / max_fp8, ) - self.assertIsNotNone(input_zero_point) + self.assertIsNone(input_zero_point) if observe_weight: weight_observer = linear.weight.weight_observer @@ -210,7 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool): atol=5e-5, rtol=0.0, ) - self.assertIsNotNone(weight_zero_point) + self.assertIsNone(weight_zero_point) else: self.assertIsNone(linear.weight.weight_observer) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 77616c1c6a..3ca58ff996 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -843,6 +843,55 @@ def test_fake_quantize_affine_cachemask(self): torch.testing.assert_close(dequantized, fake_quantized) torch.testing.assert_close(expected_mask, mask) + def test_none_zero_point_domain(self): + """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should""" + input = torch.randn(10, 256) + mapping_type = MappingType.SYMMETRIC + dtype = torch.int8 + block_size = (1, 128) + quant_min = None + quant_max = None + eps = 1e-6 + scale_dtype = torch.float32 + zero_point_dtype = torch.int64 + try: + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=None, + ) + except ValueError: + # This exception was expected + # Now test for ZeroPointDomain.NONE + _, zero_point = choose_qparams_affine( + input, + mapping_type, + block_size, + dtype, + quant_min, + quant_max, + eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=True, + zero_point_domain=ZeroPointDomain.NONE, + ) + self.assertTrue(zero_point is None) + else: + # An exception should have been thrown for zero_point_domain None + self.assertTrue( + False, + msg="A runtime exception should have been thrown for zero_point_domain None", + ) + @parameterized.expand( [ ( @@ -890,7 +939,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype): quant_min=torch.finfo(float8_dtype).min, quant_max=torch.finfo(float8_dtype).max, zero_point=None, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) expected_dequantized = dequantize_affine( expected_quantized, @@ -901,7 +950,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype): quant_min=torch.finfo(float8_dtype).min, quant_max=torch.finfo(float8_dtype).max, zero_point=None, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) self.assertTrue(torch.equal(expected_scale, scale)) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e3ac420de7..715aaeb9ec 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -81,6 +81,8 @@ def __new__( dtype=None, strides=None, ): + if zero_point_domain is None: + raise ValueError("please use ZeroPointDomain.NONE instead of None") kwargs = {} kwargs["device"] = tensor_impl.device kwargs["layout"] = ( @@ -199,7 +201,7 @@ def from_hp_to_intx( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), use_hqq: bool = False, ): @@ -258,8 +260,7 @@ def from_hp_to_intx( zero_point_domain, ) # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None - # TODO should probably consolidate ZeroPointDomain.NONE and None - if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: + if zero_point_domain == ZeroPointDomain.NONE: zero_point = None data = quantize_affine( input_float, @@ -296,14 +297,15 @@ def from_hp_to_intx_static( target_dtype: torch.dtype, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Layout = PlainLayout(), ): """Create an integer AffineQuantizedTensor from a high precision tensor using static parameters.""" + if zero_point_domain is None: + raise ValueError("please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") if target_dtype not in FP8_TYPES: - assert ( - zero_point_domain is not None - ), "zero_point_domain must be specified for non-fp8 types" assert ( zero_point is not None ), "zero_point must be specified for non-fp8 types" @@ -359,7 +361,7 @@ def from_hp_to_floatx( scale_dtype=scale_dtype, zero_point_dtype=None, preserve_zero=True, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, _layout=_layout, use_hqq=False, ) @@ -387,7 +389,7 @@ def from_hp_to_floatx_static( target_dtype=target_dtype, quant_min=math.ceil(torch.finfo(target_dtype).min), quant_max=math.ceil(torch.finfo(target_dtype).max), - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, _layout=_layout, ) else: diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 3a4253bb3f..95175caacf 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -54,10 +54,12 @@ def from_hp_to_intx( block_size: Tuple[int, ...], quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, _layout: Optional[Layout] = None, ): """Converts a floating point tensor to a Marlin QQQ quantized tensor.""" + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") original_shape = input_float.shape input_float = _layout.pre_process(input_float) nbits = int(math.log2(quant_max - quant_min + 1)) diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 06509c7b91..cbbe1b581d 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -104,11 +104,12 @@ def __init__( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ): super().__init__() assert granularity is not None, "granularity is None" - + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") self.mapping_type = mapping_type self.target_dtype = target_dtype self.granularity = granularity diff --git a/torchao/quantization/qat/affine_fake_quantized_tensor.py b/torchao/quantization/qat/affine_fake_quantized_tensor.py index b84200ac9c..f60c858b73 100644 --- a/torchao/quantization/qat/affine_fake_quantized_tensor.py +++ b/torchao/quantization/qat/affine_fake_quantized_tensor.py @@ -41,6 +41,9 @@ def forward( preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> "AffineFakeQuantizedTensor": + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + def apply_fake_quant_fn(t: torch.Tensor): assert isinstance(t, AffineFakeQuantizedTensor) qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) @@ -158,6 +161,8 @@ def from_float( preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _ToAffineFakeQuantized.apply( original_input, mapping_type, diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index cd3813291f..925a0eed3c 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -96,6 +96,8 @@ def __init__( group_size: Optional[int] = None, is_symmetric: Optional[bool] = None, ): + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") self.dtype = dtype self.granularity = self._get_granularity(granularity, group_size) self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3a73b97ad1..02af4ced91 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -387,7 +387,7 @@ def insert_observers_( eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ) # Create a linear module @@ -688,7 +688,7 @@ def int4_weight_only( group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False, - zero_point_domain=None, + zero_point_domain=ZeroPointDomain.NONE, ): """ Applies uint4 weight-only asymmetric per-group quantization to linear layers, using @@ -733,7 +733,7 @@ def apply_int4_weight_only_quant(weight): assert ( type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" - if zero_point_domain is None: + if zero_point_domain == ZeroPointDomain.NONE: # the first value is the default one zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] else: @@ -877,6 +877,7 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight): # weight settings mapping_type = MappingType.SYMMETRIC + weight_zero_point_domain = ZeroPointDomain.NONE def get_weight_block_size(x): return (1, x.shape[1]) @@ -903,6 +904,7 @@ def get_weight_block_size(x): eps=eps, zero_point_dtype=zero_point_dtype, _layout=layout, + zero_point_domain=weight_zero_point_domain, ) weight = to_linear_activation_quantized(weight, input_quant_func) return weight diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 8b0ce28434..05be8c5c30 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -284,7 +284,7 @@ def quantize_affine( output_dtype: torch.dtype, quant_min: Optional[Union[int, float]] = None, quant_max: Optional[Union[int, float]] = None, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> torch.Tensor: """ Args: @@ -319,6 +319,10 @@ def quantize_affine( Output: quantized tensor with requested dtype """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") return _quantize_affine( input, block_size, @@ -327,7 +331,7 @@ def quantize_affine( output_dtype, quant_min, quant_max, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, ) @@ -340,7 +344,7 @@ def _quantize_affine( output_dtype: torch.dtype, quant_min: Optional[Union[int, float, bool]] = None, quant_max: Optional[Union[int, float, bool]] = None, - zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + zero_point_domain: str = ZeroPointDomain.INT.name, ) -> torch.Tensor: """op definition that has compatible signatures with custom op library @@ -363,6 +367,7 @@ def _quantize_affine( zero_point, quant_min, quant_max, + output_dtype, zero_point_domain, ).to(output_dtype) @@ -374,7 +379,8 @@ def _quantize_affine_no_dtype_cast( zero_point: Optional[torch.Tensor], quant_min: Union[int, float], quant_max: Union[int, float], - zero_point_domain: Optional[str] = ZeroPointDomain.INT.name, + quant_dtype: torch.dtype, + zero_point_domain: str = ZeroPointDomain.INT.name, ) -> torch.Tensor: """ The op does the following: @@ -418,13 +424,12 @@ def _quantize_affine_no_dtype_cast( assert ( zero_point is None ), "zero_point should be None when zero_point_domain is NONE" - quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max) - elif zero_point_domain is None: - # This case handles quantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" - quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + if _is_float8_type(quant_dtype): + quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max) + else: + quant = torch.clamp( + torch.round(input * (1.0 / scale)), quant_min, quant_max + ) else: assert zero_point_domain == ZeroPointDomain.FLOAT.name mid_point = (quant_max + quant_min + 1) / 2 @@ -470,6 +475,10 @@ def dequantize_affine( Output: dequantized Tensor, with requested dtype or fp32 """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") return _dequantize_affine( input, block_size, @@ -478,7 +487,7 @@ def dequantize_affine( input_dtype, quant_min, quant_max, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, output_dtype=output_dtype, ) @@ -567,16 +576,6 @@ def _dequantize_affine_no_dtype_check( ), "zero_point should be None when zero_point_domain is NONE" dequant = input.to(output_dtype) dequant = dequant * scale - elif zero_point_domain is None: - # This case handles dequantization for float8 we expect no zero point and no zero point domain - assert ( - zero_point is None - ), "zero_point should be None when zero_point_domain is None" - assert _is_float8_type( - input.dtype - ), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}" - dequant = input.to(output_dtype) - dequant = dequant * scale else: assert ( zero_point_domain == ZeroPointDomain.FLOAT.name @@ -624,6 +623,10 @@ def fake_quantize_affine( value during quantization default is ZeroPointDomain.INT """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") (_, fq) = _do_fake_quantize_affine( input, block_size, @@ -666,6 +669,10 @@ def fake_quantize_affine_cachemask( ) """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") + elif zero_point_domain is None and zero_point is not None: + raise ValueError("zero_point should be None when zero_point_domain is NONE") (q, dq) = _do_fake_quantize_affine( input, block_size, @@ -703,6 +710,7 @@ def _do_fake_quantize_affine( zero_point, quant_min, quant_max, + quant_dtype, zero_point_domain.name, ) dq = _dequantize_affine_no_dtype_check( @@ -730,7 +738,7 @@ def choose_qparams_affine( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -764,6 +772,8 @@ def choose_qparams_affine( Output: Tuple of scales and zero_points Tensor with requested dtype """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _choose_qparams_affine( input, mapping_type.name, @@ -775,7 +785,7 @@ def choose_qparams_affine( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, ) @@ -791,7 +801,7 @@ def choose_qparams_affine_with_min_max( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, ) -> Tuple[torch.Tensor, torch.Tensor]: """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` operator that pass in min_val and max_val directly instead of deriving these from a single input. @@ -803,6 +813,8 @@ def choose_qparams_affine_with_min_max( difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val and then scale/zero_point, we pass in min_val/max_val directly """ + if zero_point_domain is None: + raise ValueError("Please use ZeroPointDomain.NONE instead of None") return _choose_qparams_affine( None, mapping_type.name, @@ -814,7 +826,7 @@ def choose_qparams_affine_with_min_max( scale_dtype, zero_point_dtype, preserve_zero, - zero_point_domain.name if zero_point_domain is not None else None, + zero_point_domain.name, min_val, max_val, ) @@ -921,17 +933,17 @@ def _choose_qparams_affine( raise ValueError( "preserve_zero == False is not supported for symmetric quantization" ) - if ( - zero_point_domain is not None - and zero_point_domain == ZeroPointDomain.FLOAT.name - ): + if zero_point_domain == ZeroPointDomain.FLOAT.name: # TODO INT should not be a valid ZeroPointDomain for symmetric quantization since # symmetric quant doesn't have a zero_point raise ValueError( "zero_point_domain should be ZeroPointDomain.INT or ZeroPointDomain.NONE for symmetric quantization" ) + if zero_point_domain == ZeroPointDomain.NONE.name: + zero_point = None + else: + zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) scale = torch.clamp(scale, min=eps) - zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) else: assert mapping_type == MappingType.ASYMMETRIC.name scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) From 2aed684cf368d2156d634d8e53333847ae4089b5 Mon Sep 17 00:00:00 2001 From: "Jane (Yuan) Xu" <31798555+janeyx99@users.noreply.github.com> Date: Wed, 29 Jan 2025 15:24:21 -0500 Subject: [PATCH 037/115] Pass all args to pytest.main to propagate user options like -k (#1640) Pass all args to pytest.main to propage user options like -k Tested locally with `python test/test_ops.py -k test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant` which previously just ran all the tests but after this PR will run 60, the same number as `pytest test/test_ops.py -k test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant` --- test/test_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 26671ddf40..54efefb026 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,4 +1,5 @@ import itertools +import sys import pytest import torch @@ -614,4 +615,4 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact if __name__ == "__main__": - pytest.main([__file__]) + pytest.main(sys.argv) From 2d8c8ebe17d8ce31f9ff847330fb6df6d3c5f875 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 29 Jan 2025 12:51:01 -0800 Subject: [PATCH 038/115] only run docs CI jobs on PRs when docs have changed (#1612) only run docs CI jobs when docs have changed --- .github/workflows/doc_build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/doc_build.yml b/.github/workflows/doc_build.yml index d16ed0340b..27ae54975d 100644 --- a/.github/workflows/doc_build.yml +++ b/.github/workflows/doc_build.yml @@ -9,10 +9,10 @@ on: tags: - v[0-9]+.[0-9]+.[0-9] - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + pull_request: paths: - 'docs/**' - '!docs/**' - pull_request: workflow_dispatch: concurrency: From 0c428237cb3334d2e23fb45c1e2504bf208f6ffe Mon Sep 17 00:00:00 2001 From: Hao Dong <60164894+haodongucsb@users.noreply.github.com> Date: Wed, 29 Jan 2025 15:35:26 -0800 Subject: [PATCH 039/115] Fix `.item()` issue in running parallel evaluation for BO mixed precision Differential Revision: D68726705 Pull Request resolved: https://github.com/pytorch/ao/pull/1630 --- .../mixed_precision/scripts/BO_acc_modelsize.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_modelsize.py b/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_modelsize.py index 1db980104c..df7f670b41 100644 --- a/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_modelsize.py +++ b/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_modelsize.py @@ -3,7 +3,6 @@ import torch import torch.multiprocessing as mp from ax.service.ax_client import AxClient, ObjectiveProperties -from BO_acc_throughput import define_parameter_list from utils import ( cal_model_size, cal_wikitext_ppl, @@ -174,12 +173,12 @@ def eval_in_parallel( model, tokenizer = load_model(checkpoint, f"cuda:{gpu_id}") print(f"Process {proc_id} on GPU {gpu_id} starts!") - + dict_config = dict(config) quantize_by_fqn_to_config( - model=model, device=f"cuda:{gpu_id}", fqn_to_config=dict(config) + model=model, device=f"cuda:{gpu_id}", fqn_to_config=dict_config ) - eval_results = eval(model, tokenizer, num_PPL_eval_samples, config) + eval_results = eval(model, tokenizer, num_PPL_eval_samples, dict_config) return_dict[proc_id] = (trial_id, config, eval_results) @@ -206,7 +205,7 @@ def run_parallel_BO( initial_samples, ): # TODO: add default parameter list if not specified - parameters_list = define_parameter_list() + parameters_list = load_parameters_from_json(parameters_list) initial_points_set = load_initial_samples(initial_samples) num_BO_initial_samples = len(initial_points_set) From aa0b7ca1942fb72e8056f2b033108e12016c7a98 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 29 Jan 2025 18:56:27 -0500 Subject: [PATCH 040/115] Split contributor guide into quantization overview (#1618) There's a lot of content in the contributor guide that belongs better to "Quantization Overview", so here we split the content and put them in the right pages. --- docs/source/contributor_guide.rst | 276 ++---------------------------- docs/source/quantization.rst | 241 +++++++++++++++++++++++++- 2 files changed, 251 insertions(+), 266 deletions(-) diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index 7d4d20cc65..ab6d433e27 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -1,261 +1,19 @@ Contributor Guide ------------------------- -.. toctree:: - :maxdepth: 3 - -Objective -========= -In this doc we’ll talk about -(1). How different optimization techniques are structured in torchao -(2). How to contribute to torchao - -Note: the doc is heavily focused on inference right now, but we plan to expand to cover training techniques in the future as well. - -torchao Stack Overview -====================== - -First we want to lay out the torchao stack:: - - Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. - --------------------------------------------------------------------------------------------- - Quantized Tensors (derived dtypes): AffineQuantizedTensor, CodebookQuantizedTensor - --------------------------------------------------------------------------------------------- - Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize - --------------------------------------------------------------------------------------------- - Basic dtypes: uint1-uint7, int1-int8, float3-float8 - - -Any quantization algorithm will be using some components from the above stack, for example int4_weight_only quantization uses: -(1) weight only quantization flow -(2) `tinygemm bf16 activation + int4 weight kernel `__ and `quant primitive ops `__ -(3) `AffineQuantizedTensor `__ tensor subclass with `TensorCoreTiledLayout `__ -(4) torch.uint4 dtype (simulated with quant_min/quant_max right now) - -Note: we'll also talk about how to compose sparsity with quantization in the Quantized Tensors section - -Basic DTypes -~~~~~~~~~~~~ -`dtype `__ is a bit of overloaded term, by basic dtype, we mean the dtypes that makes sense without any extra metadata (e.g. makes sense when people call ``torch.empty(.., dtype)``), for more details please check out: dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833 - -No matter what quantization we are doing, in the end we will be using some low precision dtypes to represent the quantized data, the dtypes we aim to support in torchao are: - -* ``torch.uint1`` to ``torch.uint8`` available in pytorch 2.3 and later -* ``torch.int1`` to ``torch.int8`` available in pytorch 2.6 and later -* ``torch.float3_e2_m0``, ``torch.float4_e2_m1``, ``torch.float4_e3_m0``, ``torch.float5_e2_m2``, ``torch.float5_e3_m1``, ``torch.float6_e2_m3``, ``torch.float6_e3_m2``, ``torch.float8_e4m3fn``, ``torch.float8_e5m2``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2fnuz`` (float8 is added to torch, we also plan to add float4 and float6 to torch if they become popular) - -Note some of the above are prototype only for now. We'll consider adding then to pytorch core when they become popular and have hardware support. - -Current Support -############### -In terms of actual implementation, there are two parts: -1). In PyTorch, we need to add the dtype to torch.dtype, e.g. torch.uint2, example: pytorch/pytorch#117208, but these are just placeholders so that we can use torch.uint2. -2). Outside of PyTorch (e.g. in torchao), we implement the tensor operations for these dtypes with tensor subclasses, also a standard packing format is needed. - -Adding placeholder dtype in PyTorch -*********************************** - -As mentioned in dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833, the criteria for adding dtype in PyTorch is that it shows wide adoption. For the above mentioned fundamental dtypes, the ones that are supported in PyTorch are: - -* ``torch.uint1`` to ``torch.uint8``, ``torch.int1`` to ``torch.int8``, ``torch.float8_e4m3fn``, ``torch.float8_e5m2``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2fnuz`` - -For the other types we plan to wait until there is more evidence of wide adoption and hardware support. - -Implementing tensor operations for these dtypes with Tensor subclasses -********************************************************************** -For this, the requirement is we decide on a "standard" packing format, and hopefully one that is amenable to efficient implementation, but for both uintx and floatx we haven't integrate enough kernels to decide on this. So current `packing implementations `__ are ont final. We can revisit after there are more uintx, intx and floatx kernels being integrated into torchao. - -Integrate Tensor subclass to pytorch native factory functions -************************************************************* -After that we can connect the factory function with the tensor subclass, for example: ``torch.empty(..., dtype=torch.int4, ...)`` can create a ``Int4Tensor`` tensor subclass with the packing format decided in the previous step. - -Quantization Primitive Ops -~~~~~~~~~~~~~~~~~~~~~~~~~~ -Quantization primitive ops means the operators used to convert between low preicison quantized tensors and high precision tensors. We will mainly have the following quantization primitive operators: -choose_qparams ops: that chooses quantization parameter based on the original Tensor, typically used in dynamic quantization, e.g. scale and zero_point for affine quantization -quantize op: quantizes the original high precision tensor to the low precision tensor with the dtypes mentioned in previous section based on the quantization parameters -dequantize op: dequantizes the low precision tensor into the high precision tensor based on quantization parameters - -There could be variations of the above to accommodate specific use cases, for example for static quantization we may have ``choose_qparams_affine_with_min_max`` that will choose quantization parameters based on min/max values derived from the observation process. - -Efficient kernels -~~~~~~~~~~~~~~~~~ -We'll also have efficient kernels that works with the low precision tensors, for example - -`_weight_int4pack_mm `__ the tinygemm int4 kernel (bf16 activation + int4 weight) -`int_matmul `__ that takes two int8 tensors and outputs an int32 tensor -`int_scaled_matmul `__ that does matmul and also applies a scale to the result. - -Note: We can also rely on torch.compile to generate kernels (through triton), for example the current int8 weight only quantization `kernel `__ just relies on torch.compile to get speedup. In this case there is no specific "efficient kernel" that's corresponding to the type of quantization. - -Quantized Tensors (derived dtypes) +General Guide on Extending torchao ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -On top of the basic dtypes, quantization primitive operators and efficient kernels, we can glue everything together and build out a Quantized (low precision) Tensor by subclassing torch.Tensor that can be constructed from a high precision Tensor and some parameters that can configure the specific quantization user wants, we can also call this derived dtypes since it can be represented with Tensors of basic dtypes and some extra metadata like scale. - -Existing example in torchao is ``AffineQuantizedTensor``, meaning the low precision Tensor is quantized from the high precision Tensor by an affine mapping, that is: ``low_precision_val = high_precision_val / scale + zero_point``, where ``scale``/``zero_point`` are the quantization parameters that can be calculated by quantization primitive ops or through some optimization procedure. Affine quantization is a very common type of quantization, since it's straightforward that when we try to map from higher precision values to lower precision values, we do an affine transformation (``high_preicsion_val / scale + zero_point``). Another common type of quantization, especially for lower bitwidths (e.g. lower than 4 bit) is codebook / look up table based quantization. - -Layout and TensorImpl -##################### -Native tensors have a hardcoded list of selections of `layout `__, most common one is strided layout, it provides a strided, multi-dimensional view of storage, we also have some sparse and mkldnn layout. - -Take `sparse COO tensor `__ as an example, it has `torch.sparse_coo` layout, and `SparseTensorImpl `__ which changes how the tensor is stored. - -The idea of packing the tensor into different formats fits nicely with the layout concept, that’s why we want to reuse this for packing. We can use `Layout` for different type of packing format and `TensorImpl` for different storage format implementations. And new TensorImpl that stores the Tensor in a packed format can be added at python level tensor subclasses without modifying C++ pytorch core code. - -For example, for ``_weight_int4pack_mm`` we need to pack the weight to an format that is friendly for Tensor Core, we call it `TensorCoreTiledLayout `__. We add a ``tensor_impl`` for the quantized tensor to store the packed (or unpacked) weight, and we use ``layout`` to store different parameters that's relevant for packing:: - - class AffineQuantizedTensor(...): - # tensor_impl is also implemented with tensor subclass - tensor_impl: torch.Tensor - - # to not conflict with existing layout property, we use `_layout` - @property - def _layout(self) -> Layout: - return self.tensor_impl._layout - -Note that layout is an abstraction not only for custom data representation, it is also used for how the -`TensorImpl` interacts with different operators, e.g. the same data representation can have different -implementations when running the same operator, e.g. transpose, quantized_linear, but the operator semantics should stay the same. - -Quantize + Sparse Tensor can also be supported through the Layout abstraction, for example, `int4 weight only quantization + sparse `__. We also provide some common utils that helps people to add different layouts to a quantized tensor, please check out the developer guide below for code examples. - -Quantization Algorithms/Flows -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -On the top of the stack will be the final quantization algorithms and quantization flows. Traditionally we have weight only quantization, dynamic quantization and static quantization, but now we are also seeing more types of quantization coming up. - -For demonstration purposes, let's say after previous step we have ``AffineQuantizedTensor`` and ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an ``AffineQuantizedTensor`` with corresponding dtype. - -Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in the `Writing Your Own Tensor Subclass `__ tutorial. - -Weight Only Quantization -######################## -This is the simplest form of quantization and it's easy to apply weight only quantization to the model, especially since we have Quantized Tensor. all we need to do is:: - linear_module.weight = torch.nn.Parameter(to_affine_quantized_intx(linear_module.weight, ...), requires_grad=False)) - -apply the above to all linear modules in the model and we'll get a weight only quantized model. - -Dynamic Activation and Weight Quantization -########################################## - -This is called "dynamic quantization" before but it means we quantize activation dynamically at runtime, and also quantize the weights as well. Compared to the weight only quantization, the main question is how do we apply the quantization to activation. In torchao, the common pattern we use is by applying ``to_linear_activation_quantized`` on top of quantized weight:: - quantized_weight = to_affine_quantized(linear_module.weight) - activation_and_weight_quantized = to_linear_activation_quantized(quantized_weight) - linear_module.weight = torch.nn.Parameter(activation_and_weight_quantized, requires_grad=False)) - -``to_linear_activation_quantized`` is used to apply quantization to activation, it takes a ``input_quant_func`` that will quantize the activation and the original weight, and during runtime when it encounters a ``F.linear`` op, it will apply the stored input_qunat_func to activation and redispatch to ``F.linear`` with quantized activation and weight. - -If the above does not work, user can also do module swaps, or use ``torch.fx.symbolic_trace()`` to get a traced module that you can `modify `__. - -But using tensor subclass is preferred because it is easier for serialization/deserialization, if we use tensor subclasses to support dynamic quantization, then we can load the quantized weights directly without further preparation for the model. Otherwise, we'd need to do module swap or other modifications to the model first before loading the quantized weights. - -Static Activation Quantization and Weight Quantization -###################################################### -Static quantization means activation is statically quantized instead of dynamically quantized at runtime. In terms of flow, static quantization requires calibration with sample data in order that we can figure out the appropriate quantization parameters. - -At the high level there are three steps for static quantization: (1) insert observers (2) calibration (3) quantize the model - -Insert Observers -**************** -In insert observers step, we need to add observer modules to input (and output) activation and weight of the operator to collect statistics of the Tensor. So there are two things we need to address, how to define observer module? how to add observer module to the model. +For a new use case, for example, a training dtype (like fp4 training), it's fine to start with adding a new tensor subclass in prototype folder `torchao/prototype `__, but you could also take a look at ``AffineQuantizedTensor`` if what you want to do is mostly supported there, e.g. adding int3 kernel for the exact same affine quantization. Please feel free to open an issue and if you have questions on what to do for a specific new use case. For more details, please refer to our `quantization overview page `__. -How to define observer module -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Observers are specific to: (1) type of quantization (e.g. affine quantization, look up table based quantization) (2) type of stats we want to track, e.g. min max observer, moving average observer. - -Generally an observer module should define `forward `__ and `calculate_qparams `__ - -For affine quantization, we defined `AffineQuantizedMinMaxObserver `__ that records min_val/max_val based on the granularity of affine quantization, and also defines how to calculate_qparams based on the recorded stats. - -How to add observer module to the model -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -1. Use Tensor Subclasses - If the only operator you are interested in quantizing is linear, you can use `linear activation weight observer `__, we also have a corresponding `insert_observer_ `__ API that handles modifying the weight of linear. - -2. Module Swap - Alternatively, you could also define and `ObservedLinear `__ module (or other module types) and swap the non observed with the observed module - -Calibration -^^^^^^^^^^^ -Calibration step is typically straightforward, typically we just need to run the model through the calibration dataset. For more complicated calibration (e.g. where we record all inputs and do optimizations based on all inputs), we'll cover some of them in next section. - -Quantize -^^^^^^^^ -We can reuse the ``quantize_`` API but provide a different ``apply_tensor_subclass`` function that converts the observed linear module to a linear module with quantized weight and statically quantized input activation, this can be done in the same manner as the dynamic quantization (with ``to_linear_activation_quantized``), see `example `__. - -Alternatively, user can do `module swap `__ as well. - -Other Quantization Flows -######################## - -For other quantization flow/algorithms that does not fit into any of the above, we also intend to provide examples for common patterns. For example, `GPTQ like quantization flow `__ that is adopted by `Autoround `__, it uses `MultiTensor `__ and module hooks to optimize the module. - -If you are working on a new quantization algorithm/flow and not sure how to implement it in a PyTorch native way, please feel free to open an issue to describe how your algorithm works and we can help advise on the implementation details. - -Training -######## -The above flow are mainly focused on inference, but low bit dtype Tensors can be used in training as well. - -Quantization Aware Training -*************************** -TODO - - -Low Bit Optimizers -****************** -Today we have some prototype low bit optimizers: `main/torchao/prototype/low_bit_optim `__ that implements a specific type of 4 bit, 8 bit and float8, and is also composable with FSDP (with look up table quantization). - -Quantized Training -****************** -Similar to low bit optimizers, we have quantized training prototype in `main/torchao/prototype/quantized_training `__, and we could extend AffineQuantizedTensor to support training as well, initial enablement is in progress, but there will be a lot of follow up work needed including making it work for different kernels etc. - -You can also checkout the tutorial for `Quantized Training `__ that talks about how to make a dtype tensor subclass trainable. - -Case Study: How int4 weight only quantization works in torchao? -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To connect everything together, here is a more detailed walk through for how int4 weight only quantization is implemented in torchao. - -High Level Summary -################## - -:: - Quantization Flow: quantize_(model, int4_weight_only()) - * What happens: linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight), requires_grad=False) - * quantization primitive ops: choose_qparams and quantize_affine are called to quantize the Tensor - * quantized Tensor will be `AffineQuantizedTensor`, a quantized tensor with derived dtype (e.g. int4 with scale and zero_point) - * packing op `_convert_weight_to_int4pack` to pack the quantized weight for efficient execution - - During Model Execution: model(input) - * `torch.ops.aten._weight_int4pack_mm` is called on input and the packed weight - -During Quantization -################### -First we start with the API call: ``quantize_(model, int4_weight_only())`` what this does is it converts the weights of nn.Linear modules in the model to int4 quantized tensor (``AffineQuantizedTensor`` that is int4 dtype, asymmetric, per group quantized), using the layout for tinygemm kernel: ``tensor_core_tiled`` layout. - -* `quantize_ `__: the model level API that quantizes the weight of linear by applying the conversion function from user (second argument) -* `int4_weight_only `__: the function that returns a function that converts weight of linear to int4 weight only quantized weight - * Calls quantization primitives ops like choose_qparams_affine and quantize_affine to quantize the Tensor -* `TensorCoreTiledLayout `__: the tensor core tiled layout type, storing parameters for the packing format -* `TensorCoreTiledAQTTensorImpl `__: the tensor core tiled TensorImpl, stores the packed weight for efficient int4 weight only kernel (tinygemm kernel) - -During Model Execution -###################### - -When we run the quantized model ``model(inputs)``, we'll run through the functional linear operator in nn.Linear:: - return F.linear(input, weight, bias) - -where input is a ``bfloat16`` Tensor, weight is an int4 ``AffineQuantizedTensor``, it calls into a ``__torch_function__`` of the ``AffineQuantizedTensor`` subclass, which will end up in an implementation for ``F.linear`` when one of the input is ``AffineQuantizedTensor``, so it calls:: - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - -The ``_quantized_linear_op`` goes through the ``_AQT_QLINEAR_DISPATCH_TABLE`` and checks each dispatch conditions, if the dispatch condition passes, it will call the implementation with ``input``/``weight``/``bias``. Please check out `this doc `__ for the explanation of ``dispatch_condition`` and ``impl``. - -int4 weight only `dispatch_condition `__ checks if the input is ``bfloat16`` Tensor and weight is a uint4 ``AffineQuantizedTensor`` -wint4 weight only quantization `kernel implementation `__ takes an bfloat16 input Tensor and an int4 AffineQuantizedTensor, and call ``torch.ops.aten._weight_int4pack_mm`` with the input Tensor and the packed weight that's stored in ``weight_tensor.tensor_impl``. - -During Save/Load -################ +To contribute to existing code base: -Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc `__ for more details. +* Adding features to AffineQuantizedTensor, e.g. making it trainable, add tensor parallelism support etc.: `torchao/dtypes/affine_quantized_tensor.py `__ +* Adding new quantization APIs: `torchao/quantization/quant_api.py `__ +* Adding new quantization primitive ops, e.g. slight variations of existing quantization primitive ops: `torchao/quantization/quant_primitives.py `__ +* Adding new autotuned triton kernels: `torchao/kernel `__ +* Adding new custom cpu/cuda/mps kernels: `torchao/csrc `__ +* Integrating custom kernel with AffineQuantizedTensor (maybe a new layout as well): Add sparse marlin AQT layout `#621 `__ as an example. We are still not decided if we want to split ``AffineQuantizedTensor`` to more tensor subclasses or not. Adding Efficient Kernels ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -331,20 +89,6 @@ The above just talks about basic feature support, we also provide examples on ho * [TODO] QAT -General Guide on Extending torchao -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -For a new use case, for example, a training dtype (like fp4 training), it's fine to start with adding a new tensor subclass in prototype folder `torchao/prototype `__, but you could also take a look at ``AffineQuantizedTensor`` if what you want to do is mostly supported there, e.g. adding int3 kernel for the exact same affine quantization. Please feel free to open an issue and if you have questions on what to do for a specific new use case. - -To contribute to existing code base: - -* Adding features to AffineQuantizedTensor, e.g. making it trainable, add tensor parallelism support etc.: `torchao/dtypes/affine_quantized_tensor.py `__ -* Adding new quantization APIs: `torchao/quantization/quant_api.py `__ -* Adding new quantization primitive ops, e.g. slight variations of existing quantization primitive ops: `torchao/quantization/quant_primitives.py `__ -* Adding new autotuned triton kernels: `torchao/kernel `__ -* Adding new custom cpu/cuda/mps kernels: `torchao/csrc `__ -* Integrating custom kernel with AffineQuantizedTensor (maybe a new layout as well): Add sparse marlin AQT layout `#621 `__ as an example. We are still not decided if we want to split ``AffineQuantizedTensor`` to more tensor subclasses or not. - Tensor Subclass Functionality/Composability Testing ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -359,9 +103,11 @@ Kernel Microbenchmarks Before we test performance on models, we can also do some microbenchmarks on single linear operator (or other compute intensive/memory intensive) operators with different input dimensions to get a sense of speedup. For a specific kernel that you'd like to benchmark, you can create a benchmark file like `benchmarks/benchmark_aq.py `__ and run benchmark with different shapes that's important for target model. A quick way to get the relevant shape for linear op and other ops is by running the example with `this `__. Change the model with the model you are interested in optimizing, and run the following:: + python tutorials/developer_api_guide/print_op_and_shapes.py Example output:: + TORCH_FUNC= (M, K, N): 10 10 10 TORCH_FUNC= args[0] shape: torch.Size([10, 10]) diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index b5e34780b7..958325280b 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -1,4 +1,243 @@ Quantization Overview --------------------- -Coming soon! +First we want to lay out the torchao stack:: + + Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. + --------------------------------------------------------------------------------------------- + Quantized Tensors (derived dtypes): AffineQuantizedTensor, CodebookQuantizedTensor + --------------------------------------------------------------------------------------------- + Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize + --------------------------------------------------------------------------------------------- + Basic dtypes: uint1-uint7, int1-int8, float3-float8 + + +Any quantization algorithm will be using some components from the above stack, for example int4_weight_only quantization uses: +(1) weight only quantization flow +(2) `tinygemm bf16 activation + int4 weight kernel `__ and `quant primitive ops `__ +(3) `AffineQuantizedTensor `__ tensor subclass with `TensorCoreTiledLayout `__ +(4) torch.uint4 dtype (simulated with quant_min/quant_max right now) + +Note: we'll also talk about how to compose sparsity with quantization in the Quantized Tensors section + +Basic DTypes +~~~~~~~~~~~~ +`dtype `__ is a bit of overloaded term, by basic dtype, we mean the dtypes that makes sense without any extra metadata (e.g. makes sense when people call ``torch.empty(.., dtype)``), for more details please check out: dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833 + +No matter what quantization we are doing, in the end we will be using some low precision dtypes to represent the quantized data, the dtypes we aim to support in torchao are: + +* ``torch.uint1`` to ``torch.uint8`` available in pytorch 2.3 and later +* ``torch.int1`` to ``torch.int8`` available in pytorch 2.6 and later +* ``torch.float3_e2_m0``, ``torch.float4_e2_m1``, ``torch.float4_e3_m0``, ``torch.float5_e2_m2``, ``torch.float5_e3_m1``, ``torch.float6_e2_m3``, ``torch.float6_e3_m2``, ``torch.float8_e4m3fn``, ``torch.float8_e5m2``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2fnuz`` (float8 is added to torch, we also plan to add float4 and float6 to torch if they become popular) + +Note some of the above are prototype only for now. We'll consider adding then to pytorch core when they become popular and have hardware support. + +Current Support +############### +In terms of actual implementation, there are two parts: +1). In PyTorch, we need to add the dtype to torch.dtype, e.g. torch.uint2, example: pytorch/pytorch#117208, but these are just placeholders so that we can use torch.uint2. +2). Outside of PyTorch (e.g. in torchao), we implement the tensor operations for these dtypes with tensor subclasses, also a standard packing format is needed. + +Adding placeholder dtype in PyTorch +*********************************** + +As mentioned in dev-discuss.pytorch.org/t/supporting-new-dtypes-in-pytorch/1833, the criteria for adding dtype in PyTorch is that it shows wide adoption. For the above mentioned fundamental dtypes, the ones that are supported in PyTorch are: + +* ``torch.uint1`` to ``torch.uint8``, ``torch.int1`` to ``torch.int8``, ``torch.float8_e4m3fn``, ``torch.float8_e5m2``, ``torch.float8_e4m3fnuz``, ``torch.float8_e5m2fnuz`` + +For the other types we plan to wait until there is more evidence of wide adoption and hardware support. + +Implementing tensor operations for these dtypes with Tensor subclasses +********************************************************************** +For this, the requirement is we decide on a "standard" packing format, and hopefully one that is amenable to efficient implementation, but for both uintx and floatx we haven't integrate enough kernels to decide on this. So current `packing implementations `__ are ont final. We can revisit after there are more uintx, intx and floatx kernels being integrated into torchao. + +Integrate Tensor subclass to pytorch native factory functions +************************************************************* +After that we can connect the factory function with the tensor subclass, for example: ``torch.empty(..., dtype=torch.int4, ...)`` can create a ``Int4Tensor`` tensor subclass with the packing format decided in the previous step. + +Quantization Primitive Ops +~~~~~~~~~~~~~~~~~~~~~~~~~~ +Quantization primitive ops means the operators used to convert between low preicison quantized tensors and high precision tensors. We will mainly have the following quantization primitive operators: +choose_qparams ops: that chooses quantization parameter based on the original Tensor, typically used in dynamic quantization, e.g. scale and zero_point for affine quantization +quantize op: quantizes the original high precision tensor to the low precision tensor with the dtypes mentioned in previous section based on the quantization parameters +dequantize op: dequantizes the low precision tensor into the high precision tensor based on quantization parameters + +There could be variations of the above to accommodate specific use cases, for example for static quantization we may have ``choose_qparams_affine_with_min_max`` that will choose quantization parameters based on min/max values derived from the observation process. + +Efficient kernels +~~~~~~~~~~~~~~~~~ +We'll also have efficient kernels that works with the low precision tensors, for example + +`_weight_int4pack_mm `__ the tinygemm int4 kernel (bf16 activation + int4 weight) +`int_matmul `__ that takes two int8 tensors and outputs an int32 tensor +`int_scaled_matmul `__ that does matmul and also applies a scale to the result. + +Note: We can also rely on torch.compile to generate kernels (through triton), for example the current int8 weight only quantization `kernel `__ just relies on torch.compile to get speedup. In this case there is no specific "efficient kernel" that's corresponding to the type of quantization. + +Quantized Tensors (derived dtypes) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +On top of the basic dtypes, quantization primitive operators and efficient kernels, we can glue everything together and build out a Quantized (low precision) Tensor by subclassing torch.Tensor that can be constructed from a high precision Tensor and some parameters that can configure the specific quantization user wants, we can also call this derived dtypes since it can be represented with Tensors of basic dtypes and some extra metadata like scale. + +Existing example in torchao is ``AffineQuantizedTensor``, meaning the low precision Tensor is quantized from the high precision Tensor by an affine mapping, that is: ``low_precision_val = high_precision_val / scale + zero_point``, where ``scale``/``zero_point`` are the quantization parameters that can be calculated by quantization primitive ops or through some optimization procedure. Affine quantization is a very common type of quantization, since it's straightforward that when we try to map from higher precision values to lower precision values, we do an affine transformation (``high_preicsion_val / scale + zero_point``). Another common type of quantization, especially for lower bitwidths (e.g. lower than 4 bit) is codebook / look up table based quantization. + +Layout and TensorImpl +##################### +Native tensors have a hardcoded list of selections of `layout `__, most common one is strided layout, it provides a strided, multi-dimensional view of storage, we also have some sparse and mkldnn layout. + +Take `sparse COO tensor `__ as an example, it has `torch.sparse_coo` layout, and `SparseTensorImpl `__ which changes how the tensor is stored. + +The idea of packing the tensor into different formats fits nicely with the layout concept, that’s why we want to reuse this for packing. We can use `Layout` for different type of packing format and `TensorImpl` for different storage format implementations. And new TensorImpl that stores the Tensor in a packed format can be added at python level tensor subclasses without modifying C++ pytorch core code. + +For example, for ``_weight_int4pack_mm`` we need to pack the weight to an format that is friendly for Tensor Core, we call it `TensorCoreTiledLayout `__. We add a ``tensor_impl`` for the quantized tensor to store the packed (or unpacked) weight, and we use ``layout`` to store different parameters that's relevant for packing:: + + class AffineQuantizedTensor(...): + # tensor_impl is also implemented with tensor subclass + tensor_impl: torch.Tensor + + # to not conflict with existing layout property, we use `_layout` + @property + def _layout(self) -> Layout: + return self.tensor_impl._layout + +Note that layout is an abstraction not only for custom data representation, it is also used for how the +`TensorImpl` interacts with different operators, e.g. the same data representation can have different +implementations when running the same operator, e.g. transpose, quantized_linear, but the operator semantics should stay the same. + +Quantize + Sparse Tensor can also be supported through the Layout abstraction, for example, `int4 weight only quantization + sparse `__. We also provide some common utils that helps people to add different layouts to a quantized tensor, please check out the developer guide below for code examples. + +Quantization Algorithms/Flows +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +On the top of the stack will be the final quantization algorithms and quantization flows. Traditionally we have weight only quantization, dynamic quantization and static quantization, but now we are also seeing more types of quantization coming up. + +For demonstration purposes, let's say after previous step we have ``AffineQuantizedTensor`` and ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an ``AffineQuantizedTensor`` with corresponding dtype. + +Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in ``Tensor Subclass Developer Guide`` section. + +Weight Only Quantization +######################## +This is the simplest form of quantization and it's easy to apply weight only quantization to the model, especially since we have Quantized Tensor. all we need to do is:: + linear_module.weight = torch.nn.Parameter(to_affine_quantized_intx(linear_module.weight, ...), requires_grad=False)) + +apply the above to all linear modules in the model and we'll get a weight only quantized model. + +Dynamic Activation and Weight Quantization +########################################## + +This is called "dynamic quantization" before but it means we quantize activation dynamically at runtime, and also quantize the weights as well. Compared to the weight only quantization, the main question is how do we apply the quantization to activation. In torchao, the common pattern we use is by applying ``to_linear_activation_quantized`` on top of quantized weight:: + quantized_weight = to_affine_quantized(linear_module.weight) + activation_and_weight_quantized = to_linear_activation_quantized(quantized_weight) + linear_module.weight = torch.nn.Parameter(activation_and_weight_quantized, requires_grad=False)) + +``to_linear_activation_quantized`` is used to apply quantization to activation, it takes a ``input_quant_func`` that will quantize the activation and the original weight, and during runtime when it encounters a ``F.linear`` op, it will apply the stored input_qunat_func to activation and redispatch to ``F.linear`` with quantized activation and weight. + +If the above does not work, user can also do module swaps, or use ``torch.fx.symbolic_trace()`` to get a traced module that you can `modify `__. + +But using tensor subclass is preferred because it is easier for serialization/deserialization, if we use tensor subclasses to support dynamic quantization, then we can load the quantized weights directly without further preparation for the model. Otherwise, we'd need to do module swap or other modifications to the model first before loading the quantized weights. + +Static Activation Quantization and Weight Quantization +###################################################### +Static quantization means activation is statically quantized instead of dynamically quantized at runtime. In terms of flow, static quantization requires calibration with sample data in order that we can figure out the appropriate quantization parameters. + +At the high level there are three steps for static quantization: (1) insert observers (2) calibration (3) quantize the model + + +Insert Observers +**************** +In insert observers step, we need to add observer modules to input (and output) activation and weight of the operator to collect statistics of the Tensor. So there are two things we need to address, how to define observer module? how to add observer module to the model. + +How to define observer module +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Observers are specific to: (1) type of quantization (e.g. affine quantization, look up table based quantization) (2) type of stats we want to track, e.g. min max observer, moving average observer. + +Generally an observer module should define `forward `__ and `calculate_qparams `__ + +For affine quantization, we defined `AffineQuantizedMinMaxObserver `__ that records min_val/max_val based on the granularity of affine quantization, and also defines how to calculate_qparams based on the recorded stats. + +How to add observer module to the model +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +1. Use Tensor Subclasses + If the only operator you are interested in quantizing is linear, you can use `linear activation weight observer `__, we also have a corresponding `insert_observer_ `__ API that handles modifying the weight of linear. + +2. Module Swap + Alternatively, you could also define and `ObservedLinear `__ module (or other module types) and swap the non observed with the observed module + +Calibration +^^^^^^^^^^^ +Calibration step is typically straightforward, typically we just need to run the model through the calibration dataset. For more complicated calibration (e.g. where we record all inputs and do optimizations based on all inputs), we'll cover some of them in next section. + +Quantize +^^^^^^^^ +We can reuse the ``quantize_`` API but provide a different ``apply_tensor_subclass`` function that converts the observed linear module to a linear module with quantized weight and statically quantized input activation, this can be done in the same manner as the dynamic quantization (with ``to_linear_activation_quantized``), see `example `__. + +Alternatively, user can do `module swap `__ as well. + +Other Quantization Flows +######################## + +For other quantization flow/algorithms that does not fit into any of the above, we also intend to provide examples for common patterns. For example, `GPTQ like quantization flow `__ that is adopted by `Autoround `__, it uses `MultiTensor `__ and module hooks to optimize the module. + +If you are working on a new quantization algorithm/flow and not sure how to implement it in a PyTorch native way, please feel free to open an issue to describe how your algorithm works and we can help advise on the implementation details. + +Training +######## +The above flow are mainly focused on inference, but low bit dtype Tensors can be used in training as well. + +Quantization Aware Training +*************************** +TODO + + +Low Bit Optimizers +****************** +Today we have some prototype low bit optimizers: `main/torchao/prototype/low_bit_optim `__ that implements a specific type of 4 bit, 8 bit and float8, and is also composable with FSDP (with look up table quantization). + +Quantized Training +****************** +Similar to low bit optimizers, we have quantized training prototype in `main/torchao/prototype/quantized_training `__, and we could extend AffineQuantizedTensor to support training as well, initial enablement is in progress, but there will be a lot of follow up work needed including making it work for different kernels etc. + +You can also checkout the tutorial for `Quantized Training `__ that talks about how to make a dtype tensor subclass trainable. + +Case Study: How int4 weight only quantization works in torchao? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +To connect everything together, here is a more detailed walk through for how int4 weight only quantization is implemented in torchao. + +Quantization Flow: quantize_(model, int4_weight_only()) + * What happens: linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight), requires_grad=False) + * quantization primitive ops: choose_qparams and quantize_affine are called to quantize the Tensor + * quantized Tensor will be `AffineQuantizedTensor`, a quantized tensor with derived dtype (e.g. int4 with scale and zero_point) + * packing op `_convert_weight_to_int4pack` to pack the quantized weight for efficient execution + +During Model Execution: model(input) + * `torch.ops.aten._weight_int4pack_mm` is called on input and the packed weight + +During Quantization +################### +First we start with the API call: ``quantize_(model, int4_weight_only())`` what this does is it converts the weights of nn.Linear modules in the model to int4 quantized tensor (``AffineQuantizedTensor`` that is int4 dtype, asymmetric, per group quantized), using the layout for tinygemm kernel: ``tensor_core_tiled`` layout. + +* `quantize_ `__: the model level API that quantizes the weight of linear by applying the conversion function from user (second argument) +* `int4_weight_only `__: the function that returns a function that converts weight of linear to int4 weight only quantized weight + * Calls quantization primitives ops like choose_qparams_affine and quantize_affine to quantize the Tensor +* `TensorCoreTiledLayout `__: the tensor core tiled layout type, storing parameters for the packing format +* `TensorCoreTiledAQTTensorImpl `__: the tensor core tiled TensorImpl, stores the packed weight for efficient int4 weight only kernel (tinygemm kernel) + +During Model Execution +###################### + +When we run the quantized model ``model(inputs)``, we'll run through the functional linear operator in nn.Linear:: + + return F.linear(input, weight, bias) + +where input is a ``bfloat16`` Tensor, weight is an int4 ``AffineQuantizedTensor``, it calls into a ``__torch_function__`` of the ``AffineQuantizedTensor`` subclass, which will end up in an implementation for ``F.linear`` when one of the input is ``AffineQuantizedTensor``, so it calls:: + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + +The ``_quantized_linear_op`` goes through the ``_AQT_QLINEAR_DISPATCH_TABLE`` and checks each dispatch conditions, if the dispatch condition passes, it will call the implementation with ``input``/``weight``/``bias``. Please check out `this doc `__ for the explanation of ``dispatch_condition`` and ``impl``. + +int4 weight only `dispatch_condition `__ checks if the input is ``bfloat16`` Tensor and weight is a uint4 ``AffineQuantizedTensor`` +wint4 weight only quantization `kernel implementation `__ takes an bfloat16 input Tensor and an int4 AffineQuantizedTensor, and call ``torch.ops.aten._weight_int4pack_mm`` with the input Tensor and the packed weight that's stored in ``weight_tensor.tensor_impl``. + +During Save/Load +################ + +Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc `__ for more details. + + From c1f5872d05a0b7c7c589c5de65eeb6262640ef92 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 29 Jan 2025 18:56:46 -0500 Subject: [PATCH 041/115] Update api_ref_quantization docs (#1619) --- docs/source/api_ref_quantization.rst | 46 +++++++++++++++++++++++----- docs/source/api_ref_sparsity.rst | 6 ++-- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index 7f2b312e85..a13cd54450 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -6,24 +6,43 @@ torchao.quantization .. currentmodule:: torchao.quantization +Main Quantization APIs +---------------------- + .. autosummary:: :toctree: generated/ :nosignatures: - autoquant quantize_ - int8_dynamic_activation_int4_weight - int8_dynamic_activation_int8_weight + autoquant + +Quantization APIs for quantize_ +------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + int4_weight_only int8_weight_only + int8_dynamic_activation_int4_weight + int8_dynamic_activation_int8_weight + uintx_weight_only + gemlite_uintx_weight_only + intx_quantization_aware_training + from_intx_quantization_aware_training float8_weight_only float8_dynamic_activation_float8_weight float8_static_activation_float8_weight - uintx_weight_only fpx_weight_only - to_linear_activation_quantized - swap_linear_with_smooth_fq_linear - smooth_fq_linear_to_inference + +Quantization Primitives +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + choose_qparams_affine choose_qparams_affine_with_min_max choose_qparams_affine_floatx @@ -40,3 +59,16 @@ torchao.quantization ZeroPointDomain TorchAODType +.. + TODO: delete these? + +Other +----- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + to_linear_activation_quantized + swap_linear_with_smooth_fq_linear + smooth_fq_linear_to_inference diff --git a/docs/source/api_ref_sparsity.rst b/docs/source/api_ref_sparsity.rst index 33c652390d..96b33af082 100644 --- a/docs/source/api_ref_sparsity.rst +++ b/docs/source/api_ref_sparsity.rst @@ -10,9 +10,9 @@ torchao.sparsity :toctree: generated/ :nosignatures: - WandaSparsifier - PerChannelNormObserver - apply_fake_sparsity sparsify_ semi_sparse_weight int8_dynamic_activation_int8_semi_sparse_weight + apply_fake_sparsity + WandaSparsifier + PerChannelNormObserver From b559c6deaf24e6ca3c1de151ffc8ff8a0e2710f3 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 29 Jan 2025 21:12:22 -0600 Subject: [PATCH 042/115] [Experimental][Kleidi] Add GEMM operator tests (#1638) --- .../kernels/cpu/aarch64/CMakeLists.txt | 4 +- ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 2 +- ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 2 +- ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 2 +- ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 2 +- torchao/experimental/ops/tests/CMakeLists.txt | 22 + .../ops/tests/build_and_run_tests.sh | 41 +- .../experimental/ops/tests/generate_tests.py | 128 ++ .../test_linear_8bit_act_xbit_weight.cpp | 1467 ++++++++++++++++- 9 files changed, 1623 insertions(+), 47 deletions(-) create mode 100755 torchao/experimental/ops/tests/generate_tests.py diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index 8751c38c81..bb4d9ac22f 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -16,10 +16,10 @@ if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUA include(FetchContent) # KleidiAI is an open-source library that provides optimized # performance-critical routines, also known as micro-kernels, for artificial - # intelligence (AI) workloads tailored for Arm® CPUs. + # intelligence (AI) workloads tailored for Arm® CPUs. FetchContent_Declare(kleidiai GIT_REPOSITORY https://git.gitlab.arm.com/kleidi/kleidiai.git - GIT_TAG 35e156d62d1d7e4d27a39f56ed7770a665628b31) # same as xnnpack for now, TODO - revisit this + GIT_TAG v1.2.0) FetchContent_MakeAvailable(kleidiai) # Temporarily exposing this to the parent scope until we wire diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h index dbda036efd..658a0feadc 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h @@ -108,7 +108,7 @@ void kernel( activation_data, weight_data, output, - /*dst_stride_row=*/n * sizeof(float), + /*dst_stride_row=*/output_m_stride * sizeof(float), /*dst_stride_col=*/sizeof(float), clamp_min, clamp_max); diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h index d3d7bd55d9..336d5a8e7f 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h @@ -109,7 +109,7 @@ void kernel( activation_data, weight_data, output, - /*dst_stride_row=*/ n * sizeof(float), + /*dst_stride_row=*/ output_m_stride * sizeof(float), /*dst_stride_col=*/ sizeof(float), clamp_min, clamp_max); diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h index 4ef499d72c..60004704ed 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h @@ -106,7 +106,7 @@ void kernel( activation_data, weight_data, output, - /*dst_stride_row=*/n * sizeof(float), + /*dst_stride_row=*/output_m_stride * sizeof(float), /*dst_stride_col=*/sizeof(float), clamp_min, clamp_max); diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h index d898cf3e5b..90db4ae3d6 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h @@ -107,7 +107,7 @@ void kernel( activation_data, weight_data, output, - /*dst_stride_row=*/n * sizeof(float), + /*dst_stride_row=*/output_m_stride * sizeof(float), /*dst_stride_col=*/sizeof(float), clamp_min, clamp_max); diff --git a/torchao/experimental/ops/tests/CMakeLists.txt b/torchao/experimental/ops/tests/CMakeLists.txt index ff41ad45b3..c3d34d6ba9 100644 --- a/torchao/experimental/ops/tests/CMakeLists.txt +++ b/torchao/experimental/ops/tests/CMakeLists.txt @@ -25,12 +25,34 @@ if(TORCHAO_BUILD_KLEIDIAI) add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1) endif() +if(TORCHAO_BUILD_ARM_I8MM) + add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM) +endif() + +if (ANDROID_ABI) + # We are cross compiling, delay test discovery till runtime + set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST) +endif() + include_directories(${TORCHAO_INCLUDE_DIRS}) set(TORCHAO_PARALLEL_BACKEND "test_dummy") add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) include(${TORCHAO_ROOT}/Utils.cmake) + +if (ANDROID_ABI) + # Given where we are today this is sufficent. But needs to be revisited. + # This is also needed for native builds, but keeping it only for cross builds + # for now given the hacky nature. + file(GLOB DOTPROD_SRC_FILES test*.cpp) + message(SRC_FILES: ${DOTPROD_SRC_FILES}) + set_property(SOURCE + ${DOTPROD_SRC_FILES} + APPEND_STRING PROPERTY + COMPILE_FLAGS " -march=armv8.2-a+dotprod ") +endif() + add_executable( test_linear_8bit_act_xbit_weight test_linear_8bit_act_xbit_weight.cpp diff --git a/torchao/experimental/ops/tests/build_and_run_tests.sh b/torchao/experimental/ops/tests/build_and_run_tests.sh index 082579e20d..4070b9304f 100644 --- a/torchao/experimental/ops/tests/build_and_run_tests.sh +++ b/torchao/experimental/ops/tests/build_and_run_tests.sh @@ -5,20 +5,57 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +target=${1:-"native"} +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) +export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests + IS_ARM64=0 +BUILD_ARM_I8MM=0 +EXTRA_ARGS="" +if [[ "${target}" == "android" ]]; then + if [[ -z ${ANDROID_NDK} ]]; then + echo "Need to set ANDROID_NDK env variable to build for Android"; + exit 1; + fi + android_abi=arm64-v8a + android_platform=28 # must be >=28 for aligned_alloc + IS_ARM64=1 + BUILD_ARM_I8MM=1 # Hardcoded for now + CMAKE_OUT=${CMAKE_OUT/cmake-out/cmake-out-android} + toolchain_file="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" + if [[ -z ${toolchain_file} ]]; then + echo "Unable to find toolchain file at ANDROID_NDK location, looking for ${toolchain_file}" + exit 1; + fi + EXTRA_ARGS="\ + -DCMAKE_TOOLCHAIN_FILE=${toolchain_file} \ + -DANDROID_ABI=${android_abi} \ + -DANDROID_PLATFORM=${android_platform} + " + echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" +fi + hash arch; retval=$? if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then IS_ARM64=1 fi -export CMAKE_OUT=/tmp/cmake-out/torchao/tests cmake \ - -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ + ${EXTRA_ARGS} \ + -DCMAKE_BUILD_TYPE=Debug \ -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ + -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ -S . \ -B ${CMAKE_OUT} cmake --build ${CMAKE_OUT} +echo "Successfully built tests." + +if [[ "${target}" != "native" ]]; then + echo "Skip running tests when cross compiling."; + exit 0; +fi + # Run ${CMAKE_OUT}/test_linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/tests/generate_tests.py b/torchao/experimental/ops/tests/generate_tests.py new file mode 100755 index 0000000000..1710a90c49 --- /dev/null +++ b/torchao/experimental/ops/tests/generate_tests.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Simple script to generate test cases for the torchao ops +from string import Template + + +def add_test_string(kernel, m, n, k, g, has_bias, has_clamp): + name = f"m{m}xn{n}xk{k}xg{g}{'_bias' if has_bias else ''}{'_clamp' if has_clamp else ''}" + d = { + "name": name, + "kernel": kernel, + "m": m, + "n": n, + "k": k, + "g": g, + "has_bias": "true" if has_bias else "false", + "has_clamp": "true" if has_clamp else "false", + } + + test_template = Template( + """ +TEST(test_linear_8bit_act_xbit_weight, Kleidi_${kernel}_${name}) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi<${kernel}>(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + $has_bias /*has_bias*/, + $has_clamp /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/$m, /*n=*/$n, /*k=*/$k, /*group_size=*/$g, &ukernel_config); +} +""" + ) + + return [test_template.safe_substitute(d)] + + +def get_test_block(kernel): + # Assuming given kleidi kernel can run with all these test cases + tests = [] + # GEMV, m == 1 + ## subtile + tests += add_test_string(kernel, 1, 2 * 1, 32, 32, False, False) + tests += add_test_string(kernel, 1, 2 * 2, 32, 32, False, False) + tests += add_test_string(kernel, 1, 2 * 3, 32, 32, True, False) + tests += add_test_string(kernel, 1, 2 * 2, 32, 32, True, True) + tests += add_test_string(kernel, 1, 2 * 3, 32, 32, False, True) + ## larger: n - must be multiple of 2 + tests += add_test_string(kernel, 1, 2 * 11, 32, 32, False, False) + tests += add_test_string(kernel, 1, 2 * 13, 32, 32, True, False) + tests += add_test_string(kernel, 1, 2 * 51, 32, 32, False, True) + tests += add_test_string(kernel, 1, 2 * 111, 32, 32, False, False) + ## larger: k, g - must be multiple of 32 + tests += add_test_string(kernel, 1, 2 * 7, 64, 32, False, False) + tests += add_test_string(kernel, 1, 2 * 11, 128, 32, True, False) + tests += add_test_string(kernel, 1, 2 * 13, 64, 64, False, True) + tests += add_test_string(kernel, 1, 2 * 17, 128, 64, False, False) + + # GEMM, m > 1 + ## subtile + tests += add_test_string(kernel, 2, 2 * 1, 32, 32, False, False) + tests += add_test_string(kernel, 2, 2 * 2, 32, 32, False, False) + tests += add_test_string(kernel, 3, 2 * 3, 32, 32, True, False) + tests += add_test_string(kernel, 4, 2 * 4, 32, 32, True, True) + tests += add_test_string(kernel, 3, 2 * 3, 32, 32, False, True) + ## larger: m + tests += add_test_string(kernel, 31, 2 * 1, 32, 32, False, False) + tests += add_test_string(kernel, 32, 2 * 2, 32, 32, False, False) + tests += add_test_string(kernel, 33, 2 * 3, 32, 32, True, False) + tests += add_test_string(kernel, 34, 2 * 4, 32, 32, True, True) + tests += add_test_string(kernel, 35, 2 * 3, 32, 32, False, True) + ## larger: n - must be multiple of 2 + tests += add_test_string(kernel, 7, 2 * 11, 32, 32, False, False) + tests += add_test_string(kernel, 17, 2 * 13, 32, 32, True, False) + tests += add_test_string(kernel, 23, 2 * 51, 32, 32, False, True) + tests += add_test_string(kernel, 41, 2 * 111, 32, 32, False, False) + ## larger: k, g - must be multiple of 32 + tests += add_test_string(kernel, 19, 2 * 7, 64, 32, False, False) + tests += add_test_string(kernel, 23, 2 * 11, 128, 32, True, False) + tests += add_test_string(kernel, 29, 2 * 13, 64, 64, False, True) + tests += add_test_string(kernel, 101, 2 * 17, 128, 64, False, False) + + return "".join(tests) + + +def main(): + kleidi_template = Template( + """ +/*****************/ +// ${kernel} tests +/*****************/ +${prologue} +${tests} +${epilogue} +""" + ) + + kleidi_kernels = [ + "dotprod_1x4x32", + "dotprod_1x8x32", + "i8mm_4x8x32", + "i8mm_8x4x32", + ] + + print("/* Generated by generate_tests.py */") + print("/* Do not modify */") + print() + print("#if defined(TORCHAO_ENABLE_KLEIDI)") + for kernel in kleidi_kernels: + prologue, epilogue = "", "" + if "i8mm" in kernel: + prologue = "#if defined(TORCHAO_ENABLE_ARM_I8MM)" + epilogue = "#endif // TORCHAO_ENABLE_ARM_I8MM" + tests = get_test_block(kernel) + d = { + "prologue": prologue, + "kernel": kernel, + "tests": tests, + "epilogue": epilogue, + } + + print(kleidi_template.safe_substitute(d)) + print("#endif // TORCHAO_ENABLE_KLEIDI") + + +if __name__ == "__main__": + main() diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index 2ed9a71819..932ecac4b2 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -13,18 +13,22 @@ #include #if defined(TORCHAO_ENABLE_KLEIDI) +#include #include +#if defined (TORCHAO_ENABLE_ARM_I8MM) +#include +#include +#endif // TORCHAO_ENABLE_ARM_I8MM #endif // TORCHAO_ENABLE_KLEIDI const float kTol = 1.0e-5; using namespace torchao::ops::linear_8bit_act_xbit_weight; -template +template UKernelConfig get_ukernel_config() { UKernelConfig config; - if constexpr (!has_kleidi) { namespace ukernel = torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; config.mr = 1; @@ -41,40 +45,19 @@ UKernelConfig get_ukernel_config() { &ukernel::prepare_weight_data; config.kernel_fn = &ukernel::kernel; - } else { -#if defined(TORCHAO_ENABLE_KLEIDI) - assert (weight_nbit == 4); - assert (!has_weight_zeros); - - namespace kernel = torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; - - auto uk = kernel::get_ukernel(); - config.mr = uk.get_mr(); - config.nr = uk.get_nr(); - - config.activation_data_size_fn = &kernel::activation_data_size; - config.weight_data_size_fn = &kernel::weight_data_size; - - config.preferred_activation_data_alignment = kernel::get_preferred_alignement(); - config.preferred_weight_data_alignment = kernel::get_preferred_alignement(); - - config.prepare_activation_data_fn = &kernel::prepare_activation_data; - config.prepare_weight_data_fn = &kernel::prepare_weight_data; - - config.kernel_fn = &kernel::kernel; -#else - assert (false); -#endif // TORCHAO_ENABLE_KLEIDI - } return config; } template -void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size) { - auto ukernel_config = - get_ukernel_config(); +void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size, const UKernelConfig* ukernel_config_arg = nullptr) { + UKernelConfig ukernel_config; + if (ukernel_config_arg != nullptr) { + ukernel_config = *ukernel_config_arg; + } else { + ukernel_config = + get_ukernel_config(); + } auto test_case = torchao:: channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( @@ -159,6 +142,51 @@ void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size) { } } +#if defined(TORCHAO_ENABLE_KLEIDI) + +enum kai_kernel_id { + dotprod_1x4x32 = 0, + dotprod_1x8x32, + i8mm_4x8x32, + i8mm_8x4x32 +}; + +#define KAI_GEN_UKERNEL(kernel_ns) \ + namespace kernel = kernel_ns; \ + auto uk = kernel::get_ukernel(); \ + config.mr = uk.get_m_step(); \ + config.nr = uk.get_n_step(); \ + config.activation_data_size_fn = &kernel::activation_data_size; \ + config.weight_data_size_fn = &kernel::weight_data_size; \ + config.preferred_activation_data_alignment = kernel::get_preferred_alignement(); \ + config.preferred_weight_data_alignment = kernel::get_preferred_alignement(); \ + config.prepare_activation_data_fn = &kernel::prepare_activation_data; \ + config.prepare_weight_data_fn = &kernel::prepare_weight_data; \ + config.kernel_fn = &kernel::kernel; \ + +template +UKernelConfig get_ukernel_config_kleidi() { + UKernelConfig config; +#if defined (TORCHAO_ENABLE_ARM_I8MM) + if constexpr (kernel_id == i8mm_4x8x32) { + KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32); + return config; + } + if constexpr (kernel_id == i8mm_8x4x32) { + KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32); + return config; + } +#endif // TORCHAO_ENABLE_ARM_I8MM + if constexpr (kernel_id == dotprod_1x8x32) { + KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32); + return config; + } + KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32); + return config; +} + +#endif // TORCHAO_ENABLE_KLEIDI + TEST(test_linear_8bit_act_xbit_weight, Standard) { test_linear_8bit_act_xbit_weight< 4 /*weight_nbit*/, @@ -263,44 +291,1405 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) { std::runtime_error); } +// begin +/* Generated by generate_tests.py */ +/* Do not modify */ + #if defined(TORCHAO_ENABLE_KLEIDI) -TEST(test_linear_8bit_act_xbit_weight, KleidiSmall) { + +/*****************/ +// dotprod_1x4x32 tests +/*****************/ + + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn4xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn22xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn26xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn102xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn222xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn14xk64xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn22xk128xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn26xk64xg64_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn34xk128xg64) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m2xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m2xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/2, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m3xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m4xn8xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/4, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m3xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m31xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/31, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m32xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/, true /*has_kleidi*/>( - /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32); + /*m=*/32, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m33xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/33, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m34xn8xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/34, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m35xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/35, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, KleidiStandard) { +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m7xn22xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/, true /*has_kleidi*/>( - /*m=*/13, /*n=*/20, /*k=*/32, /*group_size=*/32); + /*m=*/7, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m17xn26xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/17, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, KleidiHasClamp) { +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m23xn102xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, true /*has_clamp*/, true /*has_kleidi*/>( - /*m=*/17, /*n=*/10, /*k=*/32 * 2, /*group_size=*/32); + /*m=*/23, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m41xn222xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m19xn14xk64xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/19, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, KleidiHasBias) { +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m23xn22xk128xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/23, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m29xn26xk64xg64_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, true /*has_clamp*/, true /*has_kleidi*/>( - /*m=*/23, /*n=*/18, /*k=*/32 * 3, /*group_size=*/32); + /*m=*/29, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m101xn34xk128xg64) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/101, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } + + + + +/*****************/ +// dotprod_1x8x32 tests +/*****************/ + + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn4xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn22xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn26xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn102xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn222xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn14xk64xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn22xk128xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn26xk64xg64_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn34xk128xg64) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m2xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m2xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/2, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m3xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m4xn8xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/4, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m3xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m31xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/31, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m32xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/32, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m33xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/33, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m34xn8xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/34, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m35xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/35, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m7xn22xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/7, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m17xn26xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/17, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m23xn102xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/23, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m41xn222xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m19xn14xk64xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/19, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m23xn22xk128xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/23, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m29xn26xk64xg64_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/29, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m101xn34xk128xg64) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/101, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); +} + + + + +/*****************/ +// i8mm_4x8x32 tests +/*****************/ +#if defined(TORCHAO_ENABLE_ARM_I8MM) + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn4xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn22xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn26xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn102xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn222xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn14xk64xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn22xk128xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn26xk64xg64_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn34xk128xg64) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m2xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m2xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/2, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m3xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m4xn8xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/4, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m3xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m31xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/31, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m32xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/32, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m33xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/33, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m34xn8xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/34, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m35xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/35, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m7xn22xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/7, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m17xn26xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/17, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m23xn102xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/23, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m41xn222xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m19xn14xk64xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/19, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m23xn22xk128xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/23, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m29xn26xk64xg64_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/29, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m101xn34xk128xg64) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/101, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); +} + +#endif // TORCHAO_ENABLE_ARM_I8MM + + +/*****************/ +// i8mm_8x4x32 tests +/*****************/ +#if defined(TORCHAO_ENABLE_ARM_I8MM) + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn4xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn22xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn26xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn102xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn222xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn14xk64xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn22xk128xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn26xk64xg64_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn34xk128xg64) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/1, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m2xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m2xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/2, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m3xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m4xn8xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/4, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m3xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m31xn2xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/31, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m32xn4xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/32, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m33xn6xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/33, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m34xn8xk32xg32_bias_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/34, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m35xn6xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/35, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m7xn22xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/7, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m17xn26xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/17, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m23xn102xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/23, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m41xn222xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m19xn14xk64xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/19, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m23xn22xk128xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/23, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m29xn26xk64xg64_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + true /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/29, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m101xn34xk128xg64) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, + false /*has_clamp*/, + true /*has_kleidi*/>( + /*m=*/101, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); +} + +#endif // TORCHAO_ENABLE_ARM_I8MM + #endif // TORCHAO_ENABLE_KLEIDI From 463a87274f196f7c6cc16f9761940e29b3d123db Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Wed, 29 Jan 2025 20:44:26 -0800 Subject: [PATCH 043/115] skip failing MX tests on cuda capability 10.0 (#1624) Update [ghstack-poisoned] --- test/prototype/mx_formats/test_custom_cast.py | 8 +++++++- test/prototype/mx_formats/test_mx_linear.py | 12 +++++++++++- test/prototype/mx_formats/test_mx_tensor.py | 11 ++++++++++- torchao/utils.py | 9 +++++++++ 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index 6f9a76cf19..d27e1831c9 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -40,7 +40,7 @@ sem_vals_to_f32, ) from torchao.prototype.mx_formats.mx_tensor import MXTensor -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100 torch.manual_seed(0) @@ -310,6 +310,9 @@ def test_fp4_pack_unpack(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4") +@pytest.mark.skipif( + is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0" +) def test_fp4_triton_unscaled_cast(): packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda") f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals)) @@ -320,6 +323,9 @@ def test_fp4_triton_unscaled_cast(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4") +@pytest.mark.skipif( + is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0" +) def test_fp4_triton_scaled_cast(): size = (256,) orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100 diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index d280e38c36..35afeb7959 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -18,7 +18,11 @@ swap_linear_with_mx_linear, ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_4, + is_sm_at_least_89, + is_sm_at_least_100, +) torch.manual_seed(2) @@ -99,6 +103,9 @@ def test_activation_checkpointing(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0" +) @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("bias", [False, True]) # TODO(future PR): figure out why torch.compile does not match eager when @@ -184,6 +191,9 @@ def test_inference_linear(elem_dtype, bias, input_shape): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0" +) @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_inference_compile_simple(elem_dtype): """ diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index ae87ee021e..21cb49c064 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -21,7 +21,11 @@ to_dtype, ) from torchao.quantization.utils import compute_error -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_4, + is_sm_at_least_89, + is_sm_at_least_100, +) torch.manual_seed(2) @@ -166,6 +170,8 @@ def test_transpose(elem_dtype, fp4_triton): """ if elem_dtype != DTYPE_FP4 and fp4_triton: pytest.skip("unsupported configuration") + elif fp4_triton and is_sm_at_least_100(): + pytest.skip("triton does not work yet on CUDA capability 10.0") M, K = 128, 256 block_size = 32 @@ -205,6 +211,9 @@ def test_view(elem_dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0" +) @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("all_zeros", [False, True]) diff --git a/torchao/utils.py b/torchao/utils.py index 7a17c1b104..f67463f9f7 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -630,6 +630,15 @@ def is_sm_at_least_90(): ) +# TODO(future PR): rename to 8_9, 9_0, 10_0 instead of 89, 10, 100 +def is_sm_at_least_100(): + return ( + torch.cuda.is_available() + and torch.version.cuda + and torch.cuda.get_device_capability() >= (10, 0) + ) + + TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") From 7815262d77ccbd3b56ec9cf4040f3209303c0a4c Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Thu, 30 Jan 2025 16:56:49 +0000 Subject: [PATCH 044/115] [Feat]: Add support for kleidiai quantization schemes (#1447) --- torchao/experimental/docs/readme.md | 31 ++++ ...8_dynamic_activation_intx_weight_layout.py | 147 +++++++++++++++++- torchao/experimental/quant_api.py | 100 ++++++++---- ...tivation_intx_weight_layout_target_aten.py | 84 ++++++++++ torchao/quantization/quant_api.py | 4 +- 5 files changed, 328 insertions(+), 38 deletions(-) create mode 100644 torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md index 7f0970f792..a178c9b328 100644 --- a/torchao/experimental/docs/readme.md +++ b/torchao/experimental/docs/readme.md @@ -98,6 +98,37 @@ quantize_( ) ``` +KleidiAI Int4 Kernels can be utilized on the Arm platform with PyTorch versions 2.6.0 or later by adjusting the quantization parameters as follows: + +```python +from torchao.dtypes import PlainLayout +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) +from torchao.experimental.quant_api import ( + int8_dynamic_activation_intx_weight, +) +from torchao.quantization.granularity import ( + PerGroup, + PerRow, +) +from torchao.quantization.quant_api import quantize_ +from torchao.quantization.quant_primitives import MappingType + +my_model = Model() + +quantize_( + my_model, + int8_dynamic_activation_intx_weight( + weight_dtype=torch.int4, + granularity=PerGroup(32), # PerRow() is also supported + has_weight_zeros=True, # Should be True + weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR # MappingType.SYMMETRIC can also be used but increases error + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"), + ), +) +``` + If you get stuck, consult `torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py` for a working example. diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py index 7b2b1da145..9d42596793 100644 --- a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -5,12 +5,15 @@ # LICENSE file in the root directory of this source tree. import logging +from enum import Enum, auto from typing import Optional, Tuple import torch from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + get_tensor_impl_constructor, register_layout, ) from torchao.dtypes.affine_quantized_tensor_ops import ( @@ -19,6 +22,13 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout from torchao.quantization.quant_primitives import ( ZeroPointDomain, + MappingType, + choose_qparams_affine, + quantize_affine, +) + +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, ) logger = logging.getLogger(__name__) @@ -31,17 +41,33 @@ handler.setFormatter(formatter) logger.addHandler(handler) +class Target(Enum): + """Enum that indicates the backend target""" + + NATIVE = auto() + ATEN = auto() + +def target_from_str(target: str) -> Target: + if target.lower() == "native": + return Target.NATIVE + elif target.lower() == "aten": + return Target.ATEN + else: + raise ValueError(f"Invalid target: {target}") class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout): bit_width: Optional[int] group_size: Optional[int] has_weight_zeros: Optional[bool] + # The target platform for the layout, 'native' or 'aten' + target: Optional[Target] def __init__( self, bit_width: Optional[int] = None, group_size: Optional[int] = None, has_weight_zeros: Optional[bool] = None, + target: Optional[str] = "native", ): if bit_width is not None: assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8" @@ -51,6 +77,7 @@ def __init__( self.bit_width = bit_width self.group_size = group_size self.has_weight_zeros = has_weight_zeros + self.target = target_from_str(target) if not self.has_params_set(): assert ( @@ -60,13 +87,14 @@ def __init__( ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False" def extra_repr(self): - return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}" + return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}" def has_params_set(self) -> bool: return ( (self.bit_width is not None) and (self.group_size is not None) and (self.has_weight_zeros is not None) + and (self.target is not None) ) @@ -125,9 +153,11 @@ def from_plain( scale: torch.Tensor, zero_point: Optional[torch.Tensor], layout: Layout, + bias: Optional[torch.Tensor] = None, ): assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" + assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}" # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor # when AOTI supports int @@ -136,6 +166,13 @@ def from_plain( n_tensor = torch.empty(0, n, dtype=torch.int8) k_tensor = torch.empty(0, k, dtype=torch.int8) + if layout.target == Target.ATEN: + assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" + int_data = int_data.add(8) + int_data = (int_data[::,1::2] << 4 | int_data[::,::2] ).to(torch.uint8) + packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n) + return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) + if layout.has_weight_zeros: args = [ int_data.to(torch.int8), @@ -211,16 +248,13 @@ def __tensor_unflatten__( def _linear_check(input_tensor, weight_tensor, bias): layout = weight_tensor.tensor_impl.get_layout() return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and ( - bias is None + bias is None or layout.target == Target.ATEN # Aten target allows bias ) def _linear_impl(input_tensor, weight_tensor, bias): - assert ( - bias is None - ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl" - def _impl_2d(input_tensor, weight_tensor): + def _impl_2d_native(input_tensor, weight_tensor): assert input_tensor.dim() == 2 assert weight_tensor.dim() == 2 @@ -255,6 +289,31 @@ def _impl_2d(input_tensor, weight_tensor): torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight" )(*args) + def _impl_2d_aten(input_tensor, weight_tensor): + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + group_size = weight_tensor.tensor_impl.get_layout().group_size + packed_weight = weight_tensor.tensor_impl.packed_weight + return torch.ops.aten._dyn_quant_matmul_4bit( + input_tensor, packed_weight, group_size, k_, n) + + target = weight_tensor.tensor_impl.get_layout().target + + if target == Target.ATEN: + assert ( + TORCH_VERSION_AT_LEAST_2_6 == 1 + ), "Target.ATEN requires torch >= 2.6.0" + _impl_2d = _impl_2d_aten + elif target == Target.NATIVE: + _impl_2d = _impl_2d_native + assert ( + bias is None + ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native' " + if input_tensor.dim() == 2: return _impl_2d(input_tensor, weight_tensor) @@ -268,8 +327,82 @@ def _impl_2d(input_tensor, weight_tensor): res = res.reshape(*lead_shape, m, n) return res - register_aqt_quantized_linear_dispatch( _linear_check, _linear_impl, ) + + +class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor): + """ + PackedLinearInt8DynamicActivationIntxWeightAtenTensor quantized tensor subclass which inherits AffineQuantizedTensor class. + """ + + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(), + use_hqq: bool = False, + bias: Optional[torch.Tensor] = None + ): + assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization" + assert isinstance( + _layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}" + assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'." + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + + scale, zero_point = choose_qparams_affine( + input_float, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None + # TODO should probably consolidate ZeroPointDomain.NONE and None + if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: + zero_point = None + data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + # Note: output will be uint8 tensor for sub byte tensors for now + + data = _layout.post_process(data) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + +to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index 4e0906d0a0..e77d09d98b 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import sys import logging from typing import Optional, Union @@ -18,14 +19,18 @@ PerGroup, PerRow, ) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, +) +from torchao.dtypes import PlainLayout logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) -import sys handler = logging.StreamHandler(sys.stdout) -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) @@ -489,6 +494,8 @@ def quantize(self, model: nn.Module) -> nn.Module: from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( PackedLinearInt8DynamicActivationIntxWeightLayout, + to_packedlinearint8dynamicactivationintxweight_quantized_intx, + Target, ) from torchao.quantization.linear_activation_quantized_tensor import ( to_linear_activation_quantized, @@ -508,7 +515,7 @@ def int8_dynamic_activation_intx_weight( has_weight_zeros: bool = False, weight_mapping_type=MappingType.ASYMMETRIC, act_mapping_type=MappingType.ASYMMETRIC, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"), # PlainLayout() also works, but will be slow ): """ Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers. @@ -531,13 +538,25 @@ def int8_dynamic_activation_intx_weight( - The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32) - act_mapping_type must be MappingType.ASYMMETRIC """ - try: - torch.ops.torchao._pack_8bit_act_4bit_weight - except AttributeError: - raise Exception( - "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." - + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance." - ) + + def is_torchao_op_skippable(layout): + return ( + isinstance(layout, PlainLayout) or + ( + isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and + layout.target == Target.ATEN + ) + ) + + if not is_torchao_op_skippable(layout): + try: + torch.ops.torchao._pack_8bit_act_4bit_weight + except AttributeError: + raise Exception( + "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." + + " You can also set target to 'aten' if you are using ARM CPU." + + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance." + ) dtype_to_bit_width = { torch.int1: 1, @@ -555,8 +574,9 @@ def int8_dynamic_activation_intx_weight( ) bit_width = dtype_to_bit_width[weight_dtype] layout_arg = layout + propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == Target.ATEN - def apply(weight): + def apply(weight, bias: Optional[torch.Tensor] = None): if isinstance(granularity, PerGroup): group_size = granularity.group_size elif isinstance(granularity, PerRow): @@ -569,6 +589,11 @@ def apply(weight): assert weight.shape[-1] % group_size == 0 layout = layout_arg + scale_dtype = None + tensor_quantizer = to_affine_quantized_intx + quant_min = -(1 << (bit_width - 1)) + quant_max = (1 << (bit_width - 1)) - 1 + if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout): assert ( weight.device == torch.device("cpu") @@ -584,25 +609,40 @@ def apply(weight): bit_width=bit_width, group_size=group_size, has_weight_zeros=has_weight_zeros, + target="aten" if layout.target == Target.ATEN else "native", ) - - quant_min = -(1 << (bit_width - 1)) - quant_max = (1 << (bit_width - 1)) - 1 - weight = to_affine_quantized_intx( - weight, - mapping_type=weight_mapping_type, - block_size=(1, group_size), - target_dtype=torch.int32, - quant_min=quant_min, - quant_max=quant_max, - eps=torch.finfo(torch.float32).eps, - zero_point_dtype=torch.int8, - preserve_zero=has_weight_zeros, - zero_point_domain=ZeroPointDomain.INT - if has_weight_zeros - else ZeroPointDomain.NONE, - _layout=layout, - ) + if layout.target == Target.ATEN: + if weight_dtype != torch.int4 or \ + has_weight_zeros != True or \ + weight_mapping_type == MappingType.ASYMMETRIC: + raise NotImplementedError( + f"target 'aten' requires:\n" + f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n" + f"- has_weight_zeros to be True,\n" + f"- weight_dtype to be torch.int4,\n" + f"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR" + ) + assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" + if torch.backends.kleidiai.is_available(): + if isinstance(granularity, PerGroup): + scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype + tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx + + quantizer_args = [weight, + weight_mapping_type, + (1, group_size), + torch.int32, + quant_min, + quant_max, + torch.finfo(torch.float32).eps, + scale_dtype, + torch.int8, + has_weight_zeros, + ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE, + layout, + False] + ([bias] if propagate_bias else []) + + weight = tensor_quantizer(*quantizer_args) # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused # with the kernel and it should not be applied separately @@ -620,7 +660,7 @@ def apply(weight): weight = to_linear_activation_quantized(weight, activation_quant_func) return weight - return _get_linear_subclass_inserter(apply) + return _get_linear_subclass_inserter(apply, propagate_bias=propagate_bias) class UIntxWeightOnlyQuantizedLinear(nn.Module): diff --git a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py new file mode 100644 index 0000000000..c1c5ed771e --- /dev/null +++ b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch + +from torchao.dtypes import PlainLayout +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) +from torchao.experimental.quant_api import ( + int8_dynamic_activation_intx_weight, +) +from torchao.quantization.granularity import ( + PerGroup, + PerRow, +) +from torchao.quantization.quant_api import quantize_ +from torchao.utils import unwrap_tensor_subclass +from torchao.quantization.quant_primitives import MappingType + + +class TestPackedLinearInt8DynamicActivationIntxWeightLayoutAten(unittest.TestCase): + def test_accuracy(self): + """ + Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing + its results to the results of a reference model that uses PlainLayout() + """ + granularities = [PerRow()] + m = 32 + n = 128 + k = 256 + activations = torch.randn(m, k) + weight_mapping_type = MappingType.SYMMETRIC_NO_CLIPPING_ERR + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) + + for weight_dtype in [ + torch.int4, + ]: + for has_weight_zeros in [True]: + for granularity in granularities: + print( + f"Testing weight_dtype={weight_dtype}, has_weight_zeros={ + has_weight_zeros}, granularity={granularity}" + ) + quantized_model = copy.deepcopy(model) + quantize_( + quantized_model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + weight_mapping_type=weight_mapping_type, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout( + target="aten"), # default + ), + ) + + quantized_model_reference = copy.deepcopy(model) + quantize_( + quantized_model_reference, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=PlainLayout(), + ), + ) + + with torch.no_grad(): + res = quantized_model(activations) + ref = quantized_model_reference(activations) + + mean_err = ((res - ref).abs() / ref).mean() + self.assertTrue(mean_err < 0.04) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 02af4ced91..bbe9b1cb6b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -450,13 +450,15 @@ def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" -def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, **kwargs): +def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs): """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs) to the weight of linear module """ def insert_subclass(lin): requires_grad = allow_requires_grad and lin.weight.requires_grad + if propagate_bias == True: + kwargs["bias"] = lin.bias lin.weight = torch.nn.Parameter( constructor(lin.weight, **kwargs), requires_grad=requires_grad ) From 48fdd310b3977a0db2ceba37a7725192cd2aafd4 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 30 Jan 2025 13:21:55 -0800 Subject: [PATCH 045/115] Ruff lint (#1646) lint --- ...8_dynamic_activation_intx_weight_layout.py | 50 +++++++---- torchao/experimental/quant_api.py | 89 +++++++++++-------- ...tivation_intx_weight_layout_target_aten.py | 5 +- 3 files changed, 85 insertions(+), 59 deletions(-) diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py index 9d42596793..d4e6284ffc 100644 --- a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -21,12 +21,11 @@ ) from torchao.dtypes.utils import AQTTensorImpl, Layout from torchao.quantization.quant_primitives import ( - ZeroPointDomain, MappingType, + ZeroPointDomain, choose_qparams_affine, quantize_affine, ) - from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_6, ) @@ -41,12 +40,14 @@ handler.setFormatter(formatter) logger.addHandler(handler) + class Target(Enum): """Enum that indicates the backend target""" NATIVE = auto() ATEN = auto() + def target_from_str(target: str) -> Target: if target.lower() == "native": return Target.NATIVE @@ -55,6 +56,7 @@ def target_from_str(target: str) -> Target: else: raise ValueError(f"Invalid target: {target}") + class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout): bit_width: Optional[int] group_size: Optional[int] @@ -157,7 +159,10 @@ def from_plain( ): assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" - assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}" + assert layout.target in { + Target.NATIVE, + Target.ATEN, + }, f"Unexpected target: {layout.target}" # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor # when AOTI supports int @@ -167,10 +172,14 @@ def from_plain( k_tensor = torch.empty(0, k, dtype=torch.int8) if layout.target == Target.ATEN: - assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" + assert ( + TORCH_VERSION_AT_LEAST_2_6 + ), "aten target is requires torch version > 2.6.0" int_data = int_data.add(8) - int_data = (int_data[::,1::2] << 4 | int_data[::,::2] ).to(torch.uint8) - packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n) + int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) + packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight( + int_data, scale, bias, layout.group_size, k, n + ) return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) if layout.has_weight_zeros: @@ -248,12 +257,11 @@ def __tensor_unflatten__( def _linear_check(input_tensor, weight_tensor, bias): layout = weight_tensor.tensor_impl.get_layout() return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and ( - bias is None or layout.target == Target.ATEN # Aten target allows bias + bias is None or layout.target == Target.ATEN # Aten target allows bias ) def _linear_impl(input_tensor, weight_tensor, bias): - def _impl_2d_native(input_tensor, weight_tensor): assert input_tensor.dim() == 2 assert weight_tensor.dim() == 2 @@ -299,14 +307,13 @@ def _impl_2d_aten(input_tensor, weight_tensor): group_size = weight_tensor.tensor_impl.get_layout().group_size packed_weight = weight_tensor.tensor_impl.packed_weight return torch.ops.aten._dyn_quant_matmul_4bit( - input_tensor, packed_weight, group_size, k_, n) + input_tensor, packed_weight, group_size, k_, n + ) target = weight_tensor.tensor_impl.get_layout().target if target == Target.ATEN: - assert ( - TORCH_VERSION_AT_LEAST_2_6 == 1 - ), "Target.ATEN requires torch >= 2.6.0" + assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0" _impl_2d = _impl_2d_aten elif target == Target.NATIVE: _impl_2d = _impl_2d_native @@ -327,6 +334,7 @@ def _impl_2d_aten(input_tensor, weight_tensor): res = res.reshape(*lead_shape, m, n) return res + register_aqt_quantized_linear_dispatch( _linear_check, _linear_impl, @@ -354,12 +362,17 @@ def from_hp_to_intx( zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(), use_hqq: bool = False, - bias: Optional[torch.Tensor] = None + bias: Optional[torch.Tensor] = None, ): - assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization" + assert ( + use_hqq == False + ), "PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization" assert isinstance( - _layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}" - assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'." + _layout, PackedLinearInt8DynamicActivationIntxWeightLayout + ), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}" + assert ( + _layout.target == Target.ATEN + ), "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'." original_shape = input_float.shape input_float = _layout.pre_process(input_float) @@ -405,4 +418,7 @@ def from_hp_to_intx( dtype=input_float.dtype, ) -to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx + +to_packedlinearint8dynamicactivationintxweight_quantized_intx = ( + PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx +) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index e77d09d98b..ea89e98303 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -4,8 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import sys import logging +import sys from typing import Optional, Union import torch @@ -15,6 +15,7 @@ quantize_per_channel_group, ) +from torchao.dtypes import PlainLayout from torchao.quantization.granularity import ( PerGroup, PerRow, @@ -22,15 +23,13 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_6, ) -from torchao.dtypes import PlainLayout logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) handler = logging.StreamHandler(sys.stdout) -formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s") +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) logger.addHandler(handler) @@ -494,8 +493,8 @@ def quantize(self, model: nn.Module) -> nn.Module: from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( PackedLinearInt8DynamicActivationIntxWeightLayout, - to_packedlinearint8dynamicactivationintxweight_quantized_intx, Target, + to_packedlinearint8dynamicactivationintxweight_quantized_intx, ) from torchao.quantization.linear_activation_quantized_tensor import ( to_linear_activation_quantized, @@ -515,7 +514,9 @@ def int8_dynamic_activation_intx_weight( has_weight_zeros: bool = False, weight_mapping_type=MappingType.ASYMMETRIC, act_mapping_type=MappingType.ASYMMETRIC, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"), # PlainLayout() also works, but will be slow + layout=PackedLinearInt8DynamicActivationIntxWeightLayout( + target="native" + ), # PlainLayout() also works, but will be slow ): """ Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers. @@ -540,13 +541,10 @@ def int8_dynamic_activation_intx_weight( """ def is_torchao_op_skippable(layout): - return ( - isinstance(layout, PlainLayout) or - ( - isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and - layout.target == Target.ATEN - ) - ) + return isinstance(layout, PlainLayout) or ( + isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) + and layout.target == Target.ATEN + ) if not is_torchao_op_skippable(layout): try: @@ -574,7 +572,10 @@ def is_torchao_op_skippable(layout): ) bit_width = dtype_to_bit_width[weight_dtype] layout_arg = layout - propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == Target.ATEN + propagate_bias = ( + isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) + and layout_arg.target == Target.ATEN + ) def apply(weight, bias: Optional[torch.Tensor] = None): if isinstance(granularity, PerGroup): @@ -612,35 +613,45 @@ def apply(weight, bias: Optional[torch.Tensor] = None): target="aten" if layout.target == Target.ATEN else "native", ) if layout.target == Target.ATEN: - if weight_dtype != torch.int4 or \ - has_weight_zeros != True or \ - weight_mapping_type == MappingType.ASYMMETRIC: + if ( + weight_dtype != torch.int4 + or has_weight_zeros != True + or weight_mapping_type == MappingType.ASYMMETRIC + ): raise NotImplementedError( - f"target 'aten' requires:\n" - f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n" - f"- has_weight_zeros to be True,\n" - f"- weight_dtype to be torch.int4,\n" - f"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR" + "target 'aten' requires:\n" + "- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n" + "- has_weight_zeros to be True,\n" + "- weight_dtype to be torch.int4,\n" + "- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR" ) - assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0" + assert ( + TORCH_VERSION_AT_LEAST_2_6 + ), "aten target is requires torch version > 2.6.0" if torch.backends.kleidiai.is_available(): if isinstance(granularity, PerGroup): - scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype - tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx - - quantizer_args = [weight, - weight_mapping_type, - (1, group_size), - torch.int32, - quant_min, - quant_max, - torch.finfo(torch.float32).eps, - scale_dtype, - torch.int8, - has_weight_zeros, - ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE, - layout, - False] + ([bias] if propagate_bias else []) + scale_dtype = ( + torch.bfloat16 + ) # KleidiAI kernel requires bfloat16 scale_dtype + tensor_quantizer = ( + to_packedlinearint8dynamicactivationintxweight_quantized_intx + ) + + quantizer_args = [ + weight, + weight_mapping_type, + (1, group_size), + torch.int32, + quant_min, + quant_max, + torch.finfo(torch.float32).eps, + scale_dtype, + torch.int8, + has_weight_zeros, + ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE, + layout, + False, + ] + ([bias] if propagate_bias else []) weight = tensor_quantizer(*quantizer_args) diff --git a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py index c1c5ed771e..2a08d0e548 100644 --- a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py +++ b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py @@ -17,11 +17,9 @@ int8_dynamic_activation_intx_weight, ) from torchao.quantization.granularity import ( - PerGroup, PerRow, ) from torchao.quantization.quant_api import quantize_ -from torchao.utils import unwrap_tensor_subclass from torchao.quantization.quant_primitives import MappingType @@ -57,7 +55,8 @@ def test_accuracy(self): has_weight_zeros=has_weight_zeros, weight_mapping_type=weight_mapping_type, layout=PackedLinearInt8DynamicActivationIntxWeightLayout( - target="aten"), # default + target="aten" + ), # default ), ) From 3eb18e771bc7e830a2e56002407256052d8c5e7d Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 30 Jan 2025 20:06:25 -0800 Subject: [PATCH 046/115] float8 rowwise training: add FSDP workaround (#1629) Summary: Adds the workaround from https://github.com/pytorch/pytorch/issues/141881 to the torchao float8 rowwise recipe, to reduce memory usage when FSDP is on. Test Plan: tested in torchtitan, LLaMa 3 8B 8H100 training with rowwise peak memory decreased from 67GiB to 59GiB Reviewers: Subscribers: Tasks: Tags: --- torchao/float8/float8_linear.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 18aebaeada..6b3c0f06df 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -159,6 +159,15 @@ def backward(ctx, grad_output): elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED: weight_t_maybe_fp8_dim0 = weight_hp_t else: + if ( + c.cast_config_weight_for_grad_input.scaling_granularity + is ScalingGranularity.AXISWISE + ): + # workaround from https://github.com/pytorch/pytorch/issues/141881 + # to avoid saving float8 weight from forward to backward when + # FSDP is on + weight_hp_t = weight_hp_t + (grad_output_reshaped[0, 0] * 0) + # Note: we need https://github.com/pytorch/pytorch/issues/136267 # to be solved to have a chance to reuse max(abs(weight, dim=...)) # from the forward to get max(abs(weight)) here without reading From 122eb73a90ec4821fc02f82abad295fc5aa2a6a1 Mon Sep 17 00:00:00 2001 From: ngc92 <7938269+ngc92@users.noreply.github.com> Date: Sat, 1 Feb 2025 17:29:42 +0200 Subject: [PATCH 047/115] more stringent test for CPUOffloadOptimizer (#1650) * more stringent test for CPUOffloadOptimizer * fix missing sync --- test/prototype/test_low_bit_optim.py | 32 ++++++++++++++++--- .../prototype/low_bit_optim/cpu_offload.py | 2 ++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index acc7576e56..562a78c347 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -260,11 +260,24 @@ def test_optim_4bit_correctness(self, optim_name): @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)]) def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): device = _DEVICES[-1] - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) + # The first two layers are chosen so that they have a terrible arithmetic density. + # this means long transfers and comparatively quick computation, increasing the chances + # that missing synchronization will lead to test failures. + # The third layer is very small, here to validate non-trainable parameters, + # but shouldn't influence the timings + model1 = nn.Sequential( + nn.Linear(32, 131072), + nn.ReLU(), + nn.Linear(131072, 64), + nn.ReLU(), + nn.Linear(64, 64), + nn.ReLU(), + nn.Linear(64, 128), + ) model1.to(device) # make sure it can work in the presence of non-trainable params - model1[0].requires_grad_(False) + model1[2].requires_grad_(False) model2 = copy.deepcopy(model1) optim1 = torch.optim.AdamW(model1.parameters()) @@ -274,15 +287,26 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): offload_gradients=offload_grad, ) + rng = torch.Generator(device=device) + rng.manual_seed(42) + + # make sure to run both models separately; otherwise, model1 gives additional + # time for operations in model2 to complete, marking potential race conditions. for _ in range(2): for _ in range(grad_accum): - x = torch.randn(4, 32, device=device) + x = torch.randn(4, 32, device=device, generator=rng) model1(x).sum().backward() - model2(x).sum().backward() optim1.step() optim1.zero_grad() + # reset the rng + rng.manual_seed(42) + for _ in range(2): + for _ in range(grad_accum): + x = torch.randn(4, 32, device=device, generator=rng) + model2(x).sum().backward() + optim2.step() optim2.zero_grad() diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py index 90008f67fe..ccdd584066 100644 --- a/torchao/prototype/low_bit_optim/cpu_offload.py +++ b/torchao/prototype/low_bit_optim/cpu_offload.py @@ -107,6 +107,8 @@ def step(self, closure=None): with getattr(torch, self.device).stream(self.stream): p_device.copy_(p_host, non_blocking=True) + # make sure param H2D finishes before the next forward pass + self.stream.synchronize() self.queue.clear() return loss From 6ffe2360a7382c51b9a5a5ab30fb7aeb4b98963d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 2 Feb 2025 19:36:01 +0700 Subject: [PATCH 048/115] Fix LR scheduler issue with CPU offload optimizer (#1649) * synchronize param H2D * let CPU offload inherits Optimizer * add scheduler to test --- test/prototype/test_low_bit_optim.py | 5 +++++ torchao/prototype/low_bit_optim/cpu_offload.py | 6 +++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 562a78c347..d7d6fe7dc8 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -287,6 +287,9 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): offload_gradients=offload_grad, ) + scheduler1 = torch.optim.lr_scheduler.CosineAnnealingLR(optim1, 100) + scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim2, 100) + rng = torch.Generator(device=device) rng.manual_seed(42) @@ -299,6 +302,7 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): optim1.step() optim1.zero_grad() + scheduler1.step() # reset the rng rng.manual_seed(42) @@ -309,6 +313,7 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): optim2.step() optim2.zero_grad() + scheduler2.step() for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1) diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py index ccdd584066..b94340a32a 100644 --- a/torchao/prototype/low_bit_optim/cpu_offload.py +++ b/torchao/prototype/low_bit_optim/cpu_offload.py @@ -6,7 +6,11 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, get_available_devices -class CPUOffloadOptimizer: +# NOTE: We make this inherit Optimizer so it works with PyTorch's built-in LR +# schedulers. (those schedulers specifically check for instances of Optimizer). +# However, it won't behave exactly like Optimizer e.g. we don't call +# Optimizer.__init__(), there is no self.defaults. +class CPUOffloadOptimizer(Optimizer): def __init__( self, params: ParamsT, From 7e546292ad404251002fed7aa3b62245d2a6098e Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 3 Feb 2025 16:46:54 -0800 Subject: [PATCH 049/115] Fix ruff and make sure pre-commit is at same version (#1658) stack-info: PR: https://github.com/pytorch/ao/pull/1658, branch: drisspg/stack/32 --- .pre-commit-config.yaml | 2 +- torchao/quantization/quant_api.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3e34f1d465..79824e1061 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.5.6 + rev: v0.6.8 hooks: # Run the linter. - id: ruff diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bbe9b1cb6b..7154957a21 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -450,7 +450,9 @@ def _linear_extra_repr(self): return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}" -def _get_linear_subclass_inserter(constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs): +def _get_linear_subclass_inserter( + constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs +): """Helper function to apply the constructor that quantizes the weight Tensor (with additional kwargs) to the weight of linear module """ From b2fb664f4be31170376d6b3594037e29b21947bf Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 4 Feb 2025 09:58:22 -0800 Subject: [PATCH 050/115] Add int8 dynamic activation + int8 weight only test to TensorParallel (#1657) --- .../dtypes/test_affine_quantized_tensor_parallel.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 3abb736f92..76b6b74a3d 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -13,6 +13,7 @@ float8_dynamic_activation_float8_weight, float8_weight_only, int4_weight_only, + int8_dynamic_activation_int8_weight, int8_weight_only, ) from torchao.quantization.observer import PerRow, PerTensor @@ -166,9 +167,21 @@ def test_tp_gemlite(self, dtype): return self._test_tp(dtype) +class TestInt8dqAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel): + QUANT_METHOD_FN = staticmethod(int8_dynamic_activation_int8_weight) + COMMON_DTYPES = [torch.bfloat16] + + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tp(self, dtype): + return self._test_tp(dtype) + + common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel) common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel) common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel) +common_utils.instantiate_parametrized_tests(TestInt8dqAffineQuantizedTensorParallel) # Run only on H100 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): From 1a4c8f93c404d531e97de6c2328e857354dd0f44 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 5 Feb 2025 10:53:05 +0800 Subject: [PATCH 051/115] Add CUTLASS-based W4A4 (#1515) * add w4a4 * add test * hook up to AQT * fix quant api test * fix test * make threadblockswizzle a template param * re-use s8s4 cutlass template * add Alex's patch and some changes * fix aqt test * remove int4_cutlass.cu * apply alex's patch * update benchmark script * ruff * add some tuning * reduce num_stages to fit shared memory of small GPUs (<100kb) * replace torch timer with triton do_bench * ruff * use ZeroPointDomain.NONE * fix 3.7 typing * merge Aleksandar changes * run ruff * try replace torch/extension.h with torch/library.h * (alexsamardzic) improve error handling * ruff format * add note on cutlass naming --- ...benchmark_rowwise_scaled_linear_cutlass.py | 70 +++ benchmarks/benchmark_s8s4_cutlass.py | 52 -- setup.py | 36 +- test/dtypes/test_affine_quantized.py | 2 + test/test_rowwise_scaled_linear_cutlass.py | 104 ++++ test/test_s8s4_linear_cutlass.py | 77 --- torchao/csrc/README.md | 3 +- torchao/csrc/cuda/cutlass_extensions/common.h | 34 ++ .../rowwise_scaled_linear_cutlass/README.md | 52 ++ .../rowwise_scaled_linear_cutlass.cuh} | 456 +++++++++--------- .../rowwise_scaled_linear_cutlass_s4s4.cu | 28 ++ .../rowwise_scaled_linear_cutlass_s8s4.cu | 28 ++ torchao/dtypes/affine_quantized_tensor_ops.py | 6 + .../uintx/cutlass_int4_packed_layout.py | 43 +- torchao/ops.py | 117 ++--- torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 68 +++ 17 files changed, 734 insertions(+), 444 deletions(-) create mode 100644 benchmarks/benchmark_rowwise_scaled_linear_cutlass.py delete mode 100644 benchmarks/benchmark_s8s4_cutlass.py create mode 100644 test/test_rowwise_scaled_linear_cutlass.py delete mode 100644 test/test_s8s4_linear_cutlass.py create mode 100644 torchao/csrc/cuda/cutlass_extensions/common.h create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_cutlass/README.md rename torchao/csrc/cuda/{s8s4_linear_cutlass/s8s4_linear_cutlass.cu => rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh} (53%) create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu create mode 100644 torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu diff --git a/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py new file mode 100644 index 0000000000..c4c9c099be --- /dev/null +++ b/benchmarks/benchmark_rowwise_scaled_linear_cutlass.py @@ -0,0 +1,70 @@ +import pandas as pd +import torch +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.ops import ( + rowwise_scaled_linear_cutlass_s4s4, + rowwise_scaled_linear_cutlass_s8s4, +) + + +def benchmark_microseconds(f, *args): + return do_bench(lambda: f(*args), return_mode="median") * 1e3 + + +def get_problem(m: int, n: int, k: int, A_nbits: int, B_nbits: int): + assert A_nbits in (4, 8) and B_nbits in (4, 8) + + dev = torch.device("cuda") + A = torch.randint(-128, 127, (m, k * A_nbits // 8), dtype=torch.int8, device=dev) + A_scale = torch.randn((m,), dtype=torch.half, device=dev) + B = torch.randint( + -128, 127, size=(n, k * B_nbits // 8), dtype=torch.int8, device=dev + ) + B_scale = torch.randn((n,), dtype=torch.half, device=dev) + C = None + + return A, A_scale, B, B_scale, C + + +def benchmark(m: int, k: int, n: int): + dev = torch.device("cuda") + A_ref = torch.randn((m, k), dtype=torch.half, device=dev) + B_ref = torch.randn((n, k), dtype=torch.half, device=dev) + fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref) + + A, A_scale, B, B_scale, C = get_problem(m, n, k, 8, 4) + rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds( + rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C + ) + + A, A_scale, B, B_scale, C = get_problem(m, n, k, 4, 4) + rowwise_scaled_linear_cutlass_s4s4_time = benchmark_microseconds( + rowwise_scaled_linear_cutlass_s4s4, A, A_scale, B, B_scale, C + ) + + return { + "m": m, + "k": k, + "n": n, + "fp16_latency (ms)": fp16_time, + "rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time, + "s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time, + "rowwise_scaled_linear_cutlass_s4s4 latency (ms)": rowwise_scaled_linear_cutlass_s4s4_time, + "s4s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s4s4_time, + } + + +if __name__ == "__main__": + k_vals = (8192, 8192, 8192, 28672) + n_vals = (8192, 10240, 57344, 8192) + + results = [] + for m in tqdm([1 << i for i in range(10)]): + for n, k in zip(n_vals, k_vals): + results.append(benchmark(m, k, n)) + + df = pd.DataFrame(results) + df.to_csv("rowwise_scaled_linear_cutlass_time_results.csv", index=False) + print(df.to_markdown(index=False)) diff --git a/benchmarks/benchmark_s8s4_cutlass.py b/benchmarks/benchmark_s8s4_cutlass.py deleted file mode 100644 index fbf07ebb35..0000000000 --- a/benchmarks/benchmark_s8s4_cutlass.py +++ /dev/null @@ -1,52 +0,0 @@ -import pandas as pd -import torch -from tqdm import tqdm - -from torchao.ops import s8s4_linear_cutlass -from torchao.utils import benchmark_torch_function_in_microseconds - - -def get_problem(m, n, k): - dev = torch.device("cuda") - A_ref = torch.randn((m, k), dtype=torch.half, device=dev) - B_ref = torch.randn((k, n), dtype=torch.half, device=dev) - - A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev) - A_scale = torch.randn((m,), dtype=torch.half, device=dev) - B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev) - B_scale = torch.randn((n,), dtype=torch.half, device=dev) - C = None - - return A_ref, B_ref, A, A_scale, B, B_scale, C - - -def benchmark(m: int, k: int, n: int): - A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k) - - fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref) - s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds( - s8s4_linear_cutlass, A, A_scale, B, B_scale, C - ) - - return { - "m": m, - "k": k, - "n": n, - "fp16_latency (ms)": fp16_time, - "s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time, - "speedup (d/s)": fp16_time / s8s4_linear_cutlass_time, - } - - -if __name__ == "__main__": - k_vals = (8192, 8192, 8192, 28672) - n_vals = (8192, 10240, 57344, 8192) - - results = [] - for m in tqdm([1 << i for i in range(10)]): - for n, k in zip(n_vals, k_vals): - results.append(benchmark(m, k, n)) - - df = pd.DataFrame(results) - df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False) - print(df.to_markdown(index=False)) diff --git a/setup.py b/setup.py index 8628dc7ef4..67a8d2e576 100644 --- a/setup.py +++ b/setup.py @@ -240,30 +240,42 @@ def get_extensions(): extra_compile_args["nvcc"].append("-g") extra_link_args.append("/DEBUG") + this_dir = os.path.dirname(os.path.curdir) + extensions_dir = os.path.join(this_dir, "torchao", "csrc") + sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) + + extensions_cuda_dir = os.path.join(extensions_dir, "cuda") + cuda_sources = list( + glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) + ) + + if use_cuda: + sources += cuda_sources + use_cutlass = False if use_cuda and not IS_WINDOWS: use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") cutlass_include_dir = os.path.join(cutlass_dir, "include") + cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir) if use_cutlass: extra_compile_args["nvcc"].extend( [ "-DTORCHAO_USE_CUTLASS", "-I" + cutlass_include_dir, + "-I" + cutlass_extensions_include_dir, ] ) - - this_dir = os.path.dirname(os.path.curdir) - extensions_dir = os.path.join(this_dir, "torchao", "csrc") - sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) - - extensions_cuda_dir = os.path.join(extensions_dir, "cuda") - cuda_sources = list( - glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) - ) - - if use_cuda: - sources += cuda_sources + else: + # Remove CUTLASS-based kernels from the cuda_sources list. An + # assumption is that these files will have "cutlass" in its + # name. + cutlass_sources = list( + glob.glob( + os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True + ) + ) + sources = [s for s in sources if s not in cutlass_sources] ext_modules = [] if len(sources) > 0: diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 8be0652e9a..52b25dab82 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -11,6 +11,7 @@ from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -61,6 +62,7 @@ def get_quantization_functions( layout=CutlassInt4PackedLayout(), ) ) + base_functions.append(int4_dynamic_activation_int4_weight()) if do_sparse: base_functions.append( diff --git a/test/test_rowwise_scaled_linear_cutlass.py b/test/test_rowwise_scaled_linear_cutlass.py new file mode 100644 index 0000000000..d6203ab9a4 --- /dev/null +++ b/test/test_rowwise_scaled_linear_cutlass.py @@ -0,0 +1,104 @@ +import itertools + +import pytest +import torch + +from torchao.ops import ( + rowwise_scaled_linear_cutlass_s4s4, + rowwise_scaled_linear_cutlass_s8s4, +) +from torchao.quantization.utils import group_quantize_tensor_symmetric + +ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] +ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK = [ + (2, 512, 128), + (3, 2048, 2048), + (4, 3584, 640), + (13, 8704, 8576), + (26, 18944, 1664), + (67, 6656, 1408), +] +ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS = [False, True] +ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS = list( + itertools.product( + ROWWISE_SCALED_LINEAR_CUTLASS_DTYPE, + ROWWISE_SCALED_LINEAR_CUTLASS_BATCH_SIZE, + ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK, + ROWWISE_SCALED_LINEAR_CUTLASS_USE_BIAS, + ) +) + + +def run_test_for_op(op, xq_bits, wq_bits, dtype, batch_size, size_mnk, use_bias): + assert xq_bits in [4, 8] + assert wq_bits in [4, 8] + + size_m, size_n, size_k = size_mnk + + x = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") + w = torch.rand((size_n, size_k), dtype=dtype, device="cuda") + bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None + + x_2d = x.view(-1, x.shape[-1]) + xq_2d_s8, xq_2d_scales, xq_2d_zeros = group_quantize_tensor_symmetric( + x_2d, xq_bits, size_k, dtype + ) + assert torch.all(xq_2d_zeros == 0) + xq_s8 = xq_2d_s8.reshape(x.shape) + if xq_bits == 4: + xq = (xq_s8[..., 1::2] << 4) | (xq_s8[..., 0::2] & 0xF) + else: + xq = xq_s8 + xq_scales = xq_2d_scales.reshape(x.shape[:-1]) + + wq_s8, wq_scales, wq_zeros = group_quantize_tensor_symmetric( + w, wq_bits, size_n, dtype + ) + assert torch.all(wq_zeros == 0) + if wq_bits == 4: + wq = (wq_s8[:, 1::2] << 4) | (wq_s8[:, 0::2] & 0xF) + else: + wq = wq_s8 + + # If torch.nn.functional.linear(x, w, bias) used as reference, the + # error would be too big. The calculation below is approximately + # what rowwise_scaled_linear_cutlass kernel is doing (except that + # matrix multiplication is over integers there). + size_m_2d = x_2d.shape[0] + output_ref = ( + (xq_2d_s8.float() @ wq_s8.float().T) + * xq_2d_scales.view(size_m_2d, 1) + * wq_scales.view(1, size_n) + ) + if bias is not None: + output_ref += bias + output_ref = output_ref.to(dtype).reshape(x.shape[:-1] + (size_n,)) + + fn_inputs = (xq, xq_scales, wq, wq_scales, bias) + try: + output = op(*fn_inputs) + except NotImplementedError: + pytest.xfail("operator not implemented") + + torch.testing.assert_close(output, output_ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS +) +def test_rowwise_scaled_linear_cutlass_s4s4(dtype, batch_size, size_mnk, use_bias): + run_test_for_op( + rowwise_scaled_linear_cutlass_s4s4, 4, 4, dtype, batch_size, size_mnk, use_bias + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "dtype, batch_size, size_mnk, use_bias", ROWWISE_SCALED_LINEAR_CUTLASS_TEST_PARAMS +) +def test_rowwise_scaled_linear_cutlass_s8s4(dtype, batch_size, size_mnk, use_bias): + run_test_for_op( + rowwise_scaled_linear_cutlass_s8s4, 8, 4, dtype, batch_size, size_mnk, use_bias + ) diff --git a/test/test_s8s4_linear_cutlass.py b/test/test_s8s4_linear_cutlass.py deleted file mode 100644 index 6510adaea3..0000000000 --- a/test/test_s8s4_linear_cutlass.py +++ /dev/null @@ -1,77 +0,0 @@ -import itertools - -import pytest -import torch - -from torchao.ops import s8s4_linear_cutlass -from torchao.quantization.utils import group_quantize_tensor_symmetric -from torchao.utils import compute_max_diff - -S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16] -S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64] -S8S4_LINEAR_CUTLASS_SIZE_MNK = [ - (2, 512, 128), - (3, 2048, 2048), - (4, 3584, 640), - (13, 8704, 8576), - (26, 18944, 1664), - (67, 6656, 1408), -] -S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True] -S8S4_LINEAR_CUTLASS_TEST_PARAMS = list( - itertools.product( - S8S4_LINEAR_CUTLASS_DTYPE, - S8S4_LINEAR_CUTLASS_BATCH_SIZE, - S8S4_LINEAR_CUTLASS_SIZE_MNK, - S8S4_LINEAR_CUTLASS_USE_BIAS, - ) -) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize( - "dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS -) -def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias): - size_m, size_n, size_k = size_mnk - - input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda") - weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda") - bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None - - input_2d = input.view(-1, input.shape[-1]) - input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric( - input_2d, 8, size_k, dtype - ) - assert torch.all(input_2d_zeros == 0) - input_s8 = input_2d_s8.reshape(input.shape) - input_scales = input_2d_scales.reshape(input.shape[:-1]) - - weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric( - weight, 4, size_n, dtype - ) - assert torch.all(weight_zeros == 0) - weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF) - - # If torch.nn.functional.linear(input, weight, bias) used as - # reference, the error would be too big. The calculation below is - # approximately what s8s4_linear_cutlass kernel is doing (except - # that matrrix multiplication is over integers there)). - size_m_2d = input_2d.shape[0] - output_ref = ( - (input_2d_s8.to(dtype) @ weight_s8.to(dtype).T) - * input_2d_scales.view(size_m_2d, 1) - * weight_scales.view(1, size_n) - ) - if bias is not None: - output_ref += bias - output_ref = output_ref.reshape(input.shape[:-1] + (size_n,)) - - fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias) - try: - output = s8s4_linear_cutlass(*fn_inputs) - except NotImplementedError: - pytest.xfail("s8s4_linear_cutlass() op not implemented") - - max_diff = compute_max_diff(output, output_ref) - assert max_diff < 5e-3 diff --git a/torchao/csrc/README.md b/torchao/csrc/README.md index 1910e3d6e5..eaa08f04f7 100644 --- a/torchao/csrc/README.md +++ b/torchao/csrc/README.md @@ -8,7 +8,6 @@ The goal is that you can focus on just writing your custom CUDA or C++ kernel an To learn more about custom ops in PyTorch you can refer to the [PyTorch Custom Operators Landing Page](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html) - ## How to add your own kernel in ao We've integrated several kernels which you can use as a template for your own kernels. `tensor_core_tiled_layout` is the most straight-forward to get started with. @@ -23,6 +22,8 @@ And that's it! Once CI passes and your code merged you'll be able to point peopl If you'd like to learn more please check out [torch.library](https://pytorch.org/docs/main/library.html) +Note: All CUTLASS-based kernels should have `cutlass` in the name of their `.cu` files e.g. `rowwise_scaled_linear_cutlass_s4s4.cu` + ## Required dependencies The important dependencies are already taken care of in our CI so feel free to test in CI directly diff --git a/torchao/csrc/cuda/cutlass_extensions/common.h b/torchao/csrc/cuda/cutlass_extensions/common.h new file mode 100644 index 0000000000..f6024a752a --- /dev/null +++ b/torchao/csrc/cuda/cutlass_extensions/common.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#define CUTLASS_STATUS_CHECK(status, message_prefix) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix, \ + " : Got CUTLASS error: ", cutlassGetStatusString(status)); \ + } + +namespace torchao { + +template +struct enable_2x_kernel_for_sm80_or_later : Kernel { + template + CUTLASS_DEVICE static void invoke(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 + Kernel::invoke(std::forward(args)...); +#endif + } +}; + +template +struct enable_3x_kernel_for_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); +#endif + } +}; + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/README.md b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/README.md new file mode 100644 index 0000000000..7c36f7c7ed --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/README.md @@ -0,0 +1,52 @@ +This directory is intended to contain implementations for all of the +CUTLASS-based row-wise scaled linear operators, for non-sparse inputs +of both same and mixed data types. + +The implementation is through single kernel per SM generation, that +should reside in `rowwise_scaled_linear_kernel_cutlass.cuh` file. At +the moment, only SM8.x architectures are supported, through +`rowwise_scaled_linear_kernel_cutlass_sm8x` kernel, but the SM9.x, and +eventually higher, can and will be supported too. + +The rest of source files, besides +`rowwise_scaled_linear_kernel_cutlass.cuh` file, contain just the +corresponding template instantiation and PyTorch operator declaration +for given operator. + +In order to support new combination of data types, copy one of +existing `.cu` files, for example +`rowwise_scaled_linear_kernel_cutlass_s8s4.cu`, rename the new file, +as well as operator to be defined inside, to reflect data types to be +supported, and also change `using ElementA` and `using ElementB` +directives accordingly. + +In the `.cuh` file, looking from the bottom up, the changes needed as +follows: + +1. Optionally, in the `rowwise_scaled_linear_cutlass_check_inputs` +template, changes may be needed at the places where the last dimension +of first operand is checked - but this check will have to be updated +only for inputs of mixed data types, where wider data type is not +exactly two times wider than the other data type. +2. In the `select_config` template, a section should be added to +choose optimal configuration(s) for your kernel. The configuration +selection is critical for performance of any CUTLASS-based kernel, so +this is where the most time should and will be spent when making +changes. +3. Optionally, in the `rowwise_scaled_linear_kernel_cutlass_sm8x` +template, `using Operator` directive may need to be adjusted; namely, +for some combination of operands, `OpMultiplyAdd` may have to be used. + +After making these changes, the test file +`tests/test_rowwise_scaled_linear_cutlass.py` should be changed too - +add a test for the new operator alike to existing tests. + +To restrict build times, the implementation in `.cuh` file has some +restrictions at the moment, for example: scale tensors could be only +of `float16` or `bfloat16` data types, the output is produces to be of +the same data type as first input scale tensor, scale tensors are not +optional while bias is optional, etc. If any of these restrictions +should be removed, or if any alike changes are needed, or if support +for other architectures is needed, or if you need any kind of help in +extending this code to support other data type combinations - get in +touch with the developers. diff --git a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh similarity index 53% rename from torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu rename to torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh index 6253f8d5f7..0117f12e27 100644 --- a/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass.cuh @@ -1,4 +1,4 @@ -#include +#pragma once #include #include @@ -7,61 +7,68 @@ #if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ defined(CUDA_VERSION) && (CUDA_VERSION >= 11080) -#define BUILD_S8S4_LINEAR_CUTLASS +#define BUILD_ROWWISE_SCALED_LINEAR_CUTLASS #endif -#if defined(BUILD_S8S4_LINEAR_CUTLASS) -#include -#include +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) #include -#include +#include #include +#include -#define CUTLASS_STATUS_CHECK(status) \ - { \ - TORCH_CHECK(status == cutlass::Status::kSuccess, \ - __func__, " : Got CUTLASS error: ", \ - cutlassGetStatusString(status)); \ - } +#include "cutlass_extensions/common.h" #endif +#define OPERATOR_NAME "rowwise_scaled_linear_cutlass" + namespace torchao { -#if defined(BUILD_S8S4_LINEAR_CUTLASS) +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) template< typename ThreadblockShape, typename WarpShape, typename InstructionShape, + typename ThreadblockSwizzle, int NumStages, typename ElementA, typename ElementB, - typename ElementAccumulator, - typename Operator, - typename ElementAScale, - typename ElementBScale, + typename ElementOutput, typename ElementC, typename UseTensorC, - typename ElementOutput> -void s8s4_linear_kernel_cutlass_sm8x( + typename ElementAScale, + typename ElementBScale> +void rowwise_scaled_linear_kernel_cutlass_sm8x( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { + static_assert((cutlass::sizeof_bits::value >= 8 || + 8 % cutlass::sizeof_bits::value == 0) && + (cutlass::sizeof_bits::value >= 8 || + 8 % cutlass::sizeof_bits::value == 0)); + using SmArch = cutlass::arch::Sm80; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutOutput = cutlass::layout::RowMajor; - using ElementEpilogue = float; + // TODO: use FP32 if either ElementA/B is FP + using ElementAccumulator = int32_t; + using Operator = + std::conditional_t::value, + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAddMixedInputUpcast>; - using ThreadblockSwizzle = - cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; + using ElementEpilogue = float; constexpr auto NumEVTEpilogueStages = 1; const int m = tensor_a.size(0); const int n = tensor_b.size(0); - const int k = tensor_a.size(1); + int k = tensor_a.size(1); + if constexpr (cutlass::sizeof_bits::value < 8) { + k *= 8 / cutlass::sizeof_bits::value; + } constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; constexpr int AlignmentAScale = @@ -74,37 +81,16 @@ void s8s4_linear_kernel_cutlass_sm8x( 128 / cutlass::sizeof_bits::value; // Check for current CUTLASS limitations w.r.t. alignments. - TORCH_CHECK(k % AlignmentA == 0, - __func__, " : Number of columns of tensor A must be divisible ", - "by ", AlignmentA); - TORCH_CHECK(k % AlignmentB == 0, - __func__, " : Number of columns of tensor B must be divisible ", - "by ", AlignmentB); - TORCH_CHECK(n % AlignmentC == 0, - __func__, " : Number of columns of tensor C must be divisible ", - "by ", AlignmentC); - - using TensorAScaleTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementAScale, - AlignmentAScale, - NumEVTEpilogueStages>; - using TensorBScaleTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementBScale, - AlignmentBScale, - NumEVTEpilogueStages>; - using TensorCTileThreadMap = - cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementC, - AlignmentC, - NumEVTEpilogueStages>; + TORCH_CHECK(k % AlignmentA == 0, OPERATOR_NAME, + " : Number of columns of tensor A must be divisible by ", + AlignmentA); + TORCH_CHECK(k % AlignmentB == 0, OPERATOR_NAME, + " : Number of columns of tensor B must be divisible by ", + AlignmentB); + TORCH_CHECK(n % AlignmentC == 0, OPERATOR_NAME, + " : Number of columns of tensor C must be divisible by ", + AlignmentC); + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< ThreadblockShape, @@ -117,14 +103,14 @@ void s8s4_linear_kernel_cutlass_sm8x( using TensorAScale = cutlass::epilogue::threadblock::VisitorColBroadcast< - TensorAScaleTileThreadMap, + OutputTileThreadMap, ElementAScale, cute::Stride>; using TensorAScaleArguments = typename TensorAScale::Arguments; using TensorBScale = cutlass::epilogue::threadblock::VisitorRowBroadcast< - TensorBScaleTileThreadMap, + OutputTileThreadMap, ElementBScale, cute::Stride>; using TensorBScaleArguments = typename TensorBScale::Arguments; @@ -133,7 +119,7 @@ void s8s4_linear_kernel_cutlass_sm8x( cutlass::epilogue::threadblock::VisitorScalarBroadcast; using TensorCTensor = cutlass::epilogue::threadblock::VisitorRowBroadcast< - TensorCTileThreadMap, + OutputTileThreadMap, ElementC, cute::Stride>; using TensorC = @@ -177,26 +163,26 @@ void s8s4_linear_kernel_cutlass_sm8x( Output, EVTApplySum>; - using EVTKernel = + using EVTKernel = torchao::enable_2x_kernel_for_sm80_or_later< typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, - ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, - ElementOutput, LayoutOutput, AlignmentOutput, - ElementAccumulator, - ElementEpilogue, - cutlass::arch::OpClassTensorOp, - SmArch, - ThreadblockShape, - WarpShape, - InstructionShape, - EVTOutput, - ThreadblockSwizzle, - NumStages, - Operator, - NumEVTEpilogueStages - >::GemmKernel; - - using Gemm = cutlass::gemm::device::GemmUniversalBase; + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementOutput, LayoutOutput, AlignmentOutput, + ElementAccumulator, + ElementEpilogue, + cutlass::arch::OpClassTensorOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EVTOutput, + ThreadblockSwizzle, + NumStages, + Operator, + NumEVTEpilogueStages + >::GemmKernel>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; cutlass::gemm::GemmCoord problem_size(m, n, k); constexpr auto SplitKFactor = 1; @@ -242,7 +228,6 @@ void s8s4_linear_kernel_cutlass_sm8x( }, // EVTApplySum output_arguments // Output }; // EVTOutput - constexpr auto AvailSms = -1; typename Gemm::Arguments arguments( cutlass::gemm::GemmUniversalMode::kGemm, @@ -260,8 +245,8 @@ void s8s4_linear_kernel_cutlass_sm8x( problem_size.k(), // stride A problem_size.k(), // stride B 0, // stride C (unused) - 0, // stride D (unused) - AvailSms); + 0 // stride D (unused) + ); Gemm gemm_op; @@ -270,7 +255,7 @@ void s8s4_linear_kernel_cutlass_sm8x( // Verify that GEMM operation with given arguments can be performed // by CUTLASS. status = gemm_op.can_implement(arguments); - CUTLASS_STATUS_CHECK(status); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); // Allocate workspace for CUTLASS mixed datatypes GEMM kernel. const auto workspace_size = Gemm::get_workspace_size(arguments); @@ -280,11 +265,11 @@ void s8s4_linear_kernel_cutlass_sm8x( // Initialize CUTLASS mixed datatypes GEMM object. status = gemm_op.initialize(arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); - CUTLASS_STATUS_CHECK(status); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); // Perform mixed datatypes GEMM operation. status = gemm_op.run(at::cuda::getCurrentCUDAStream()); - CUTLASS_STATUS_CHECK(status); + CUTLASS_STATUS_CHECK(status, OPERATOR_NAME); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -293,14 +278,61 @@ template static void select_config( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { + const at::Tensor& tensor_c, at::Tensor& tensor_d) { const auto dprops = at::cuda::getCurrentDeviceProperties(); const auto is_sm8x = dprops->major == 8; if (is_sm8x) { - if constexpr (std::is_same::value && + if constexpr (std::is_same::value && + std::is_same::value) { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>; + + // some basic tuning + if (tensor_a.size(0) <= 16) { + using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 256>; + constexpr auto NumStages = 5; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else if (tensor_a.size(0) <= 32) { + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; + constexpr auto NumStages = 4; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else if (tensor_a.size(0) <= 128) { + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; + constexpr auto NumStages = 4; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } else { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 256, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + constexpr auto NumStages = 4; + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( + tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, + tensor_d); + } + return; + } else if constexpr (std::is_same::value && std::is_same::value) { using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using ThreadblockSwizzle = + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK; // A minimal heuristic to improve performance for small number // of inputs cases. @@ -308,27 +340,27 @@ static void select_config( using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<16, 32, 128>; constexpr auto NumStages = 6; - s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } else if (tensor_a.size(0) <= 32) { using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; constexpr auto NumStages = 5; - s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } else { using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; constexpr auto NumStages = 4; - s8s4_linear_kernel_cutlass_sm8x< - ThreadblockShape, WarpShape, InstructionShape, NumStages, ElementA, - ElementB, Types...>( + rowwise_scaled_linear_kernel_cutlass_sm8x< + ThreadblockShape, WarpShape, InstructionShape, ThreadblockSwizzle, + NumStages, ElementA, ElementB, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); } @@ -336,41 +368,15 @@ static void select_config( } } - TORCH_CHECK(false, - __func__, " : Operator not supported on SM", dprops->major, ".", - dprops->minor, " for given operands"); -} - -template -static void -dispatch_on_tensor_a_and_tensor_b( - const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, - const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, - const at::Tensor& tensor_c, at::Tensor& tensor_d) { - if (tensor_a.scalar_type() == at::ScalarType::Char) { - if (tensor_b.scalar_type() == at::ScalarType::Char) { - if (tensor_a.size(1) == 2 * tensor_b.size(1)) { - using ElementA = int8_t; - using ElementB = cutlass::int4b_t; - using ElementAccumulator = int32_t; - using Operator = cutlass::arch::OpMultiplyAddMixedInputUpcast; - select_config< - ElementA, ElementB, ElementAccumulator, Operator, Types...>( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, - tensor_d); - } - return; - } - } - - TORCH_CHECK(false, - __func__, " : Operator not supported for combination of data ", - "types ", tensor_a.scalar_type(), " for first operand and ", - tensor_b.scalar_type(), " for second operand"); + TORCH_CHECK(false, OPERATOR_NAME, " : Operator not supported on SM", + dprops->major, ".", dprops->minor, " for given operands"); } - -template +template< + typename ElementA, + typename ElementB, + typename ElementOutput, + typename... Types> static void dispatch_on_tensor_c( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, @@ -379,8 +385,8 @@ dispatch_on_tensor_c( if (tensor_c.numel() == 0) { using ElementC = ElementOutput; using UseTensorC = std::false_type; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; @@ -389,32 +395,32 @@ dispatch_on_tensor_c( using UseTensorC = std::true_type; if (tensor_c.scalar_type() == at::ScalarType::Half) { using ElementC = cutlass::half_t; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; } else if (tensor_c.scalar_type() == at::ScalarType::BFloat16) { using ElementC = cutlass::bfloat16_t; - dispatch_on_tensor_a_and_tensor_b< - ElementAScale, ElementBScale, ElementC, UseTensorC, ElementOutput>( + select_config< + ElementA, ElementB, ElementOutput, ElementC, UseTensorC, Types...>( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; } - TORCH_CHECK(false, - __func__, " : Operator not supported for datatype ", - tensor_c.scalar_type(), " for addend"); + TORCH_CHECK(false, OPERATOR_NAME, " : Operator not supported for datatype ", + tensor_c.scalar_type(), " for addend"); } +template static void dispatch_on_tensor_a_scale_and_tensor_b_scale( const at::Tensor& tensor_a, const at::Tensor& tensor_a_scale, const at::Tensor& tensor_b, const at::Tensor& tensor_b_scale, const at::Tensor& tensor_c, at::Tensor& tensor_d) { TORCH_CHECK(tensor_d.scalar_type() == tensor_a_scale.scalar_type(), - __func__, " : Operator not supported for output datatype ", + OPERATOR_NAME, " : Operator not supported for output datatype ", tensor_d.scalar_type(), " as it's different from the first ", " operand scale datatype ", tensor_a_scale.scalar_type()); @@ -423,7 +429,8 @@ dispatch_on_tensor_a_scale_and_tensor_b_scale( using ElementAScale = cutlass::half_t; using ElementBScale = cutlass::half_t; using ElementOutput = cutlass::half_t; - dispatch_on_tensor_c( + dispatch_on_tensor_c( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; } else if (tensor_a_scale.scalar_type() == at::ScalarType::BFloat16 && @@ -431,151 +438,144 @@ dispatch_on_tensor_a_scale_and_tensor_b_scale( using ElementAScale = cutlass::bfloat16_t; using ElementBScale = cutlass::bfloat16_t; using ElementOutput = cutlass::bfloat16_t; - dispatch_on_tensor_c( + dispatch_on_tensor_c( tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); return; } - TORCH_CHECK(false, - __func__, " : Operator not supported for combination of data ", - "types ", tensor_a_scale.scalar_type(), - " for first operand scale and ", tensor_b_scale.scalar_type(), - " for second operand scale"); + TORCH_CHECK(false, OPERATOR_NAME, + " : Operator not supported for combination of data types ", + tensor_a_scale.scalar_type(), " for first operand scale and ", + tensor_b_scale.scalar_type(), " for second operand scale"); } +template void -check_inputs( +rowwise_scaled_linear_cutlass_check_inputs( const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, - const at::Tensor& w_scale, const at::Tensor& bias) { + const at::Tensor& w_scale, const at::Tensor& bias){ // Validate layouts of arguments. - TORCH_CHECK(xq.dim() >= 2, - __func__, " : Expected xq argument to be 2D or " - "higher-dimensional tensor, got ", xq.dim(), " dims"); - TORCH_CHECK(xq.layout() == at::Layout::Strided, - __func__, " : Expected xq argument to be strided, got layout ", + TORCH_CHECK(xq.dim() >= 2, OPERATOR_NAME, + " : Expected xq argument to be 2D or higher-dimensional tensor, " + "got ", xq.dim(), " dims"); + TORCH_CHECK(xq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected xq argument to be strided, got layout ", xq.layout()); - TORCH_CHECK(x_scale.dim() == xq.dim() - 1, - __func__, " : Expected xq scale argument to be ", xq.dim() - 1, + TORCH_CHECK(x_scale.dim() == xq.dim() - 1, OPERATOR_NAME, + " : Expected xq scale argument to be ", xq.dim() - 1, "D tensor, got ", x_scale.dim(), " dims"); - TORCH_CHECK(x_scale.layout() == at::Layout::Strided, - __func__, " : Expected xq scale argument to be strided, got " - "layout ", x_scale.layout()); - TORCH_CHECK(wq.dim() == 2, - __func__, " : Expected wq argument to be 2D tensor, got ", - wq.dim(), " dims"); - TORCH_CHECK(wq.layout() == at::Layout::Strided, - __func__, " : Expected wq argument to be strided, got layout ", + TORCH_CHECK(x_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected xq scale argument to be strided, got layout ", + x_scale.layout()); + TORCH_CHECK(wq.dim() == 2, OPERATOR_NAME, + " : Expected wq argument to be 2D tensor, got ", wq.dim(), + " dims"); + TORCH_CHECK(wq.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected wq argument to be strided, got layout ", wq.layout()); - TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, - __func__, " : Expected wq scale argument to be 1D or 2D tensor, ", - "got ", w_scale.dim(), " dims"); - TORCH_CHECK(w_scale.layout() == at::Layout::Strided, - __func__, " : Expected wq scale argument to be strided, got " - "layout ", w_scale.layout()); + TORCH_CHECK(w_scale.dim() == 1 || w_scale.dim() == 2, OPERATOR_NAME, + " : Expected wq scale argument to be 1D or 2D tensor, ", "got ", + w_scale.dim(), " dims"); + TORCH_CHECK(w_scale.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected wq scale argument to be strided, got layout ", + w_scale.layout()); if (bias.numel() > 0) { - TORCH_CHECK(bias.dim() == 1, - __func__, " : Expected bias argument to be 1D tensor, got ", - bias.dim(), " dims"); - TORCH_CHECK(bias.layout() == at::Layout::Strided, - __func__, " : Expected bias argument to be strided, got ", - "layout ", bias.layout()); + TORCH_CHECK(bias.dim() == 1, OPERATOR_NAME, + " : Expected bias argument to be 1D tensor, got ", bias.dim(), + " dims"); + TORCH_CHECK(bias.layout() == at::Layout::Strided, OPERATOR_NAME, + " : Expected bias argument to be strided, got layout ", + bias.layout()); } // Validate sizes of arguments. const auto xq_sizes = xq.sizes().vec(); - TORCH_CHECK(xq_sizes.back() == 2 * wq.size(1), - __func__, " : Expected xq argument to have ", 2 * wq.size(1), - " columns, but got ", xq_sizes.back()); + TORCH_CHECK(xq_sizes.back() == wq.size(1) || + xq_sizes.back() == 2 * wq.size(1), + OPERATOR_NAME, " : Expected xq argument to have ", wq.size(1), + " or ", 2 * wq.size(1), " columns, but got ", xq_sizes.back()); const auto x_scale_sizes = x_scale.sizes().vec(); for (auto i = 0; i < x_scale_sizes.size(); ++i) - TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], - __func__, " : Expected xq scale argument size at position ", - i, " to be ", xq_sizes[i], ", but got ", x_scale_sizes[i]); - TORCH_CHECK(w_scale.numel() == wq.size(0), - __func__, " : Expected wq scale argument to have ", wq.size(0), + TORCH_CHECK(x_scale_sizes[i] == xq_sizes[i], OPERATOR_NAME, + " : Expected xq scale argument size at position ", i, " to be ", + xq_sizes[i], ", but got ", x_scale_sizes[i]); + TORCH_CHECK(w_scale.numel() == wq.size(0), OPERATOR_NAME, + " : Expected wq scale argument to have ", wq.size(0), " elements, got ", w_scale.numel(), " elements"); if (bias.numel() > 0) { - TORCH_CHECK(bias.numel() == wq.size(0), - __func__, " : Expected bias argument to have ", wq.size(0), + TORCH_CHECK(bias.numel() == wq.size(0), OPERATOR_NAME, + " : Expected bias argument to have ", wq.size(0), " elements, got ", bias.numel(), " elements"); } // Validate strides of arguments. const auto xq_strides = xq.strides(); - TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, - __func__, " : Expected xq argument in row-major layout"); + TORCH_CHECK(xq_strides[xq_strides.size() - 1] == 1, OPERATOR_NAME, + " : Expected xq argument in row-major layout"); auto xq_stride_expected = xq_strides[xq_strides.size() - 2]; for (int i = xq_strides.size() - 3; i >= 0; --i) { xq_stride_expected *= xq_sizes[i + 1]; - TORCH_CHECK(xq_strides[i] == xq_stride_expected, - __func__, " : Expected xq argument in row-major layout"); + TORCH_CHECK(xq_strides[i] == xq_stride_expected, OPERATOR_NAME, + " : Expected xq argument in row-major layout"); } - TORCH_CHECK(x_scale.is_contiguous(), - __func__, " : Expected xq scale argument to be contiguous"); + TORCH_CHECK(x_scale.is_contiguous(), OPERATOR_NAME, + " : Expected xq scale argument to be contiguous"); const auto wq_strides = wq.strides(); - TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, - __func__, " : Expected wq argument in row-major layout"); - TORCH_CHECK(w_scale.is_contiguous(), - __func__, " : Expected wq scale argument to be contiguous"); + TORCH_CHECK(wq_strides[0] >= 1 && wq_strides[1] == 1, OPERATOR_NAME, + " : Expected wq argument in row-major layout"); + TORCH_CHECK(w_scale.is_contiguous(), OPERATOR_NAME, + " : Expected wq scale argument to be contiguous"); if (bias.numel() > 0) { const auto bias_strides = bias.strides(); - TORCH_CHECK(bias_strides[0] == 1, - __func__, " : Expected bias argument to be contiguous"); + TORCH_CHECK(bias_strides[0] == 1, OPERATOR_NAME, + " : Expected bias argument to be contiguous"); } } #endif -// Perform linear operation, using corresponding CUTLASS mixed -// data-types GEMM kernel, to given arguments: -// result = (xq * x_scale) @ (wq * w_scale).T + bias -// Notes: The "x_scale" tensor is expected to be a vector, of size -// equal to number of rows of "xq" tensor. The "w_scale" tensor is -// expected to be a vector, of size equal to number of rows of "wq" -// tensor. The "bias" tensor is expected to be a vector, of size equal -// to number of rows of "wq" tensor. +// Perform linear operation, using corresponding CUTLASS datatypes +// GEMM kernel, to given arguments - result produced is: +// (tensor_a * tensor_a_scale) @ (tensor_b * tensor_b_scale).T + tensor_c +// +// Notes: The "tensor_a" and "tensor_b" are expected to be 2D tensors. +// The "tensor_a_scale" tensor is expected to be a vector, of size +// equal to number of rows of "tensor_a" tensor. The "tensor_b_scale" +// tensor is expected to be a vector, of size equal to number of rows +// of "tensor_b" tensor. The "tensor_c" tensor is expected to be a +// vector, of size equal to number of rows of "tensor_b" tensor. +template at::Tensor -s8s4_linear_cutlass( +rowwise_scaled_linear_cutlass( const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, const at::Tensor& w_scale, const at::Tensor& bias) { -#if defined(BUILD_S8S4_LINEAR_CUTLASS) +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) // Check inputs. - check_inputs(xq, x_scale, wq, w_scale, bias); + rowwise_scaled_linear_cutlass_check_inputs( + xq, x_scale, wq, w_scale, bias); // Squash the input tensors as appropriate. const auto xq_sizes = xq.sizes().vec(); const auto xq_2d = xq.reshape({-1, xq_sizes.back()}); - const auto x_scale_sizes = x_scale.sizes().vec(); const auto x_scale_1d = x_scale.reshape({-1}); const auto w_scale_1d = w_scale.reshape({-1}); - // Introduce alias names for arguments, according to the CUTLASS - // naming conventions. - const auto& tensor_a = xq_2d; - const auto& tensor_a_scale = x_scale_1d; - const auto& tensor_b = wq; - const auto& tensor_b_scale = w_scale_1d; - const auto& tensor_c = bias; - - // Create output tensor. - at::Tensor tensor_d = - tensor_a_scale.new_empty({tensor_a.size(0), tensor_b.size(0)}); + // Create result tensor. + at::Tensor result = + x_scale.new_empty({xq_2d.size(0), wq.size(0)}); // Dispatch to appropriate kernel template. - dispatch_on_tensor_a_scale_and_tensor_b_scale( - tensor_a, tensor_a_scale, tensor_b, tensor_b_scale, tensor_c, tensor_d); + dispatch_on_tensor_a_scale_and_tensor_b_scale( + xq_2d, x_scale_1d, wq, w_scale_1d, bias, result); - // Reshape and return output tensor. - auto tensor_d_sizes = xq_sizes; - tensor_d_sizes.back() = wq.size(0); - return tensor_d.reshape(tensor_d_sizes); + // Reshape and return result tensor. + auto result_sizes = xq_sizes; + result_sizes.back() = wq.size(0); + return result.reshape(result_sizes); #else - TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + TORCH_CHECK_NOT_IMPLEMENTED(false, OPERATOR_NAME); return at::Tensor{}; #endif } -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::s8s4_linear_cutlass", &s8s4_linear_cutlass); -} - } // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu new file mode 100644 index 0000000000..e455b7bdf2 --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu @@ -0,0 +1,28 @@ +#include + +#include "rowwise_scaled_linear_cutlass.cuh" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_cutlass_s4s4( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { + // Validate input datatypes. + TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, + __func__, " : The input datatypes combination ", xq.dtype(), + " for xq and ", wq.dtype(), " for wq is not supported"); + + // Dispatch to appropriate kernel template. + using ElementA = cutlass::int4b_t; + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + xq, x_scale, wq, w_scale, bias); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_cutlass_s4s4", + &rowwise_scaled_linear_cutlass_s4s4); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu new file mode 100644 index 0000000000..680822ca7f --- /dev/null +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu @@ -0,0 +1,28 @@ +#include + +#include "rowwise_scaled_linear_cutlass.cuh" + +namespace torchao { + +at::Tensor +rowwise_scaled_linear_cutlass_s8s4( + const at::Tensor& xq, const at::Tensor& x_scale, const at::Tensor& wq, + const at::Tensor& w_scale, const at::Tensor& bias) { + // Validate input datatypes. + TORCH_CHECK(xq.dtype() == at::kChar && wq.dtype() == at::kChar, + __func__, " : The input datatypes combination ", xq.dtype(), + " for xq and ", wq.dtype(), " for wq is not supported"); + + // Dispatch to appropriate kernel template. + using ElementA = int8_t; + using ElementB = cutlass::int4b_t; + return rowwise_scaled_linear_cutlass( + xq, x_scale, wq, w_scale, bias); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::rowwise_scaled_linear_cutlass_s8s4", + &rowwise_scaled_linear_cutlass_s8s4); +} + +} // namespace torchao diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index ef8691699e..54f4a72811 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -21,6 +21,8 @@ _linear_int8_act_int8_weight_block_sparse_impl, ) from torchao.dtypes.uintx.cutlass_int4_packed_layout import ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ) @@ -155,6 +157,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cutlass_check, _linear_int8_act_int4_weight_cutlass_impl, ), + ( + _linear_int4_act_int4_weight_cutlass_check, + _linear_int4_act_int4_weight_cutlass_impl, + ), ( _linear_fp_act_uint4_weight_cpu_check, _linear_fp_act_uint4_weight_cpu_impl, diff --git a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py index 9c0d0bb055..ae8ea78ceb 100644 --- a/torchao/dtypes/uintx/cutlass_int4_packed_layout.py +++ b/torchao/dtypes/uintx/cutlass_int4_packed_layout.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional import torch from torch.utils._python_dispatch import ( @@ -105,10 +106,10 @@ def from_plain( cls, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): - assert torch.all(zero_point == 0) + assert zero_point is None or torch.all(zero_point == 0) int_data_s4 = ((int_data[:, 1::2] & 0xF) << 4) | (int_data[:, 0::2] & 0xF) return cls( @@ -146,13 +147,47 @@ def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): - from torchao.ops import s8s4_linear_cutlass + from torchao.ops import rowwise_scaled_linear_cutlass_s8s4 weight = weight_tensor.tensor_impl.int_data weight_scale = weight_tensor.tensor_impl.scale input = input_tensor.tensor_impl.int_data input_scale = input_tensor.tensor_impl.scale - out = s8s4_linear_cutlass(input, input_scale, weight, weight_scale, bias) + out = rowwise_scaled_linear_cutlass_s8s4( + input, input_scale, weight, weight_scale, bias + ) + + return out + + +def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int4(input_tensor) + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and len(input_tensor.shape) >= 2 + and input_tensor.tensor_impl.scale.dtype == input_tensor.dtype + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int4(weight_tensor) + and weight_tensor.dtype == input_tensor.dtype + and len(weight_tensor.shape) == 2 + and weight_tensor.tensor_impl.scale.dtype == weight_tensor.dtype + and len(weight_tensor.tensor_impl.scale.shape) == 1 + ) + + +def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias): + from torchao.ops import rowwise_scaled_linear_cutlass_s4s4 + + weight = weight_tensor.tensor_impl.int_data + weight_scale = weight_tensor.tensor_impl.scale + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + out = rowwise_scaled_linear_cutlass_s4s4( + input, input_scale, weight, weight_scale, bias + ) return out diff --git a/torchao/ops.py b/torchao/ops.py index f4b55c4951..8b573876f2 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -20,7 +20,10 @@ "marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor" ) lib.define( - "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" + "rowwise_scaled_linear_cutlass_s4s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" +) +lib.define( + "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) @@ -514,7 +517,7 @@ def _( return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device) -def s8s4_linear_cutlass( +def rowwise_scaled_linear_cutlass_s8s4( input: Tensor, input_scale: Tensor, weight: Tensor, @@ -522,23 +525,23 @@ def s8s4_linear_cutlass( bias: Tensor, ) -> Tensor: """ - CUTLASS-based W4A8 linear operator. + CUTLASS-based row-wise scaled W4A8 linear operator. Args: - input: input tensor, quantized to 8-bit integer values. + input: quantized input tensor, in row-major layout. input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. - weight: weight matrix, quantized to 4-bit integer values, in row-major layout. + weight: quantized weight matrix, in row-major layout. weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). bias: a vector of size equal to number of rows of weight tensor, or None. Returns: output: result tensor, in row-major layout. """ - return torch.ops.torchao.s8s4_linear_cutlass.default( + return torch.ops.torchao.rowwise_scaled_linear_cutlass_s8s4.default( input, input_scale, weight, weight_scale, bias ) -@register_custom_op("torchao::s8s4_linear_cutlass") +@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s8s4") def _( input: Tensor, input_scale: Tensor, @@ -546,72 +549,46 @@ def _( weight_scale: Tensor, bias: Tensor, ) -> Tensor: - # Validate dtypes. - torch._check( - input.dtype == torch.int8, - lambda: f"input dtype {input.dtype} instead of {torch.int8}", - ) - torch._check( - input_scale.dtype in (torch.float16, torch.bfloat16), - lambda: f"input_scale dtype {input_scale.dtype} instead of {torch.float16} or {torch.bfloat16}", - ) - torch._check( - weight.dtype == torch.int8, - lambda: f"weight dtype {weight.dtype} instead of {torch.int8}", - ) - torch._check( - weight_scale.dtype == input_scale.dtype, - lambda: f"weight_scale dtype {weight_scale.dtype} instead of {input_scale.dtype}", - ) - if bias is not None: - torch._check( - bias.dtype == input_scale.dtype, - lambda: f"bias dtype {weight_scale.dtype} instead of {input_scale.dtype}", - ) - - # Validate dims. - torch._check(input.dim() >= 2, lambda: f"input is {input.dim()}D instead of >=2D") - torch._check( - input_scale.dim() == input.dim() - 1, - lambda: f"input_scale is {input_scale.dim()}D instead of {input.dim() - 1}D", - ) - torch._check(weight.dim() == 2, lambda: f"weight is {weight.dim()}D instead of 2D") - torch._check( - weight_scale.dim() == 1 or weight_scale.dim() == 2, - lambda: f"weight_scale is {weight_scale.dim()}D instead of 1D or 2D", - ) - if bias is not None: - torch._check(bias.dim() == 1, lambda: f"bias is {bias.dim()}D instead of 1D") - - # Validate shapes. - torch._check( - input.shape[-1] == 2 * weight.shape[-1], - lambda: "input and weight shapes do not match for matrix product", - ) - for i in range(input_scale.dim()): - torch._check( - input_scale.shape[i] == input.shape[i], - lambda: f"input_scale and input shapes do not match at position {i}", - ) - torch._check( - weight_scale.numel() == weight.shape[0], - lambda: f"weight_scale has {weight_scale.numel()} elements instead of {weight.shape[0]}", - ) - if bias is not None: - torch._check( - bias.numel() == weight.shape[0], - lambda: f"bias has {bias.numel()} elements instead of {weight.shape[0]}", - ) - - # Validate strides (input, input_scales and weight_scales will be - # reshape()-d by the operator, so no need to check strides for - # them). - torch._check(weight.stride(-1) == 1, lambda: "weight is not in row-major layout") - if bias is not None: - torch._check(bias.is_contiguous(), lambda: "bias is not contiguous") + # No checks here, as detailed checks are performed by the + # operator itself. return torch.empty( (*input.shape[:-1], weight.shape[0]), dtype=input_scale.dtype, device=input.device, ) + + +def rowwise_scaled_linear_cutlass_s4s4( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + """ + CUTLASS-based row-wise scaled W4A4 linear operator. + Args: + input: quantized input tensor, in row-major layout. + input_scale: scale factors for input tensor, has to be tensor of the same shape as the input tensor, minus the last dimension. + weight: quantized weight matrix, in row-major layout. + weight_scale: scale factors for weight tensor, one value per row of weight matrix (thus also tensor of the same shape as the weight tensor, minus the last dimension). + bias: a vector of size equal to number of rows of weight tensor, or None. + Returns: + output: result tensor, in row-major layout. + """ + + return torch.ops.torchao.rowwise_scaled_linear_cutlass_s4s4.default( + input, input_scale, weight, weight_scale, bias + ) + + +@register_custom_op("torchao::rowwise_scaled_linear_cutlass_s4s4") +def _( + input: Tensor, + input_scale: Tensor, + weight: Tensor, + weight_scale: Tensor, + bias: Tensor, +) -> Tensor: + return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index d0d29cf4be..aa4a51d497 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -50,6 +50,7 @@ float8_weight_only, fpx_weight_only, gemlite_uintx_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_semi_sparse_weight, @@ -102,6 +103,7 @@ "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", + "int4_dynamic_activation_int4_weight", "int8_dynamic_activation_int4_weight", "int8_dynamic_activation_int8_weight", "int8_dynamic_activation_int8_semi_sparse_weight", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7154957a21..9b7999449f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -658,6 +658,59 @@ def int8_dynamic_activation_int4_weight( ) +def apply_int4_dynamic_activation_int4_weight_quant( + weight: torch.Tensor, + layout=CutlassInt4PackedLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, +): + if not isinstance(layout, CutlassInt4PackedLayout): + raise NotImplementedError( + f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." + ) + if mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only mapping_type=SYMMETRIC is supported.") + if act_mapping_type != MappingType.SYMMETRIC: + raise NotImplementedError("Only act_mapping_type=SYMMETRIC is supported.") + + weight = to_affine_quantized_intx( + weight, + mapping_type=mapping_type, + block_size=(1, weight.shape[1]), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=torch.finfo(torch.float32).eps, + zero_point_domain=ZeroPointDomain.NONE, + _layout=layout, + ) + weight = to_linear_activation_quantized( + weight, + _int4_symm_per_token_quant_cutlass, + ) + return weight + + +def int4_dynamic_activation_int4_weight( + layout=CutlassInt4PackedLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, +): + """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear + + Args: + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric + """ + return _get_linear_subclass_inserter( + apply_int4_dynamic_activation_int4_weight_quant, + layout=layout, + mapping_type=mapping_type, + act_mapping_type=act_mapping_type, + ) + + def gemlite_uintx_weight_only( group_size: Optional[int] = 64, bit_width: int = 4, @@ -859,6 +912,20 @@ def _int8_symm_per_token_reduced_range_quant_cutlass( ) +def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: + return to_affine_quantized_intx( + x, + mapping_type=MappingType.SYMMETRIC, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int8, + quant_min=-8, + quant_max=7, + eps=1e-5, + zero_point_domain=ZeroPointDomain.NONE, + _layout=CutlassInt4PackedLayout(), + ) + + def int8_dynamic_activation_int8_weight( layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC, @@ -1300,6 +1367,7 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: _int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant, _int8_symm_per_token_reduced_range_quant_cutlass, + _int4_symm_per_token_quant_cutlass, _input_activation_quant_func_fp8, ] ) From 8afd10ed4b22b3cabd80184062c4ad58001bc68a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 5 Feb 2025 20:03:19 +0800 Subject: [PATCH 052/115] Fix compile issue for Marin qqq on sm<8.0 (#1651) * fix compile guard * remove guard on header file --- .../csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu | 55 ++++--------------- 1 file changed, 10 insertions(+), 45 deletions(-) diff --git a/torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu b/torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu index 7380f9aff2..10c3f152bd 100644 --- a/torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu +++ b/torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu @@ -30,9 +30,7 @@ #include #include "base.h" -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - #include "mem.h" -#endif +#include "mem.h" template inline std::string str(T x) { @@ -41,8 +39,6 @@ inline std::string str(T x) { namespace torchao { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using I4 = Vec; // Matrix fragments for tensor core instructions; their precise layout is // documented here: @@ -208,6 +204,8 @@ __global__ void Marlin_QQQ( int prob_k, // reduction dimension k int* locks // extra global storage for barrier synchronization ) { + // host code or device code with SM >= 80. Marlin only supports SM >= 80. +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM @@ -855,47 +853,8 @@ __global__ void Marlin_QQQ( } } } -} - -#else - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin_QQQ( - const int4* __restrict__ A, // int8 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // int32 global_reduce buffer of shape - // (max_par*16*4)xn, as int8 tensor core's output is - // int32 dtype - int4* __restrict__ D, // fp16 output buffer of shape mxn - const float* __restrict__ s_tok, // fp32 activation per-token quantization - // scales of shape mx1 - const int4* __restrict__ s_ch, // fp32 weight per-channel quantization - // scales of shape 1xn - const int4* __restrict__ s_group, // fp16 weight per-group quantization - // scales of shape (k/groupsize)xn, when - // group_blocks=-1, it should be nullptr - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Marlin is not implemented yet for SM < 8.0 - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_qqq_gemm(..) requires CUDA_ARCH >= 8.0"); - return; -} - #endif +} // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, @@ -1132,6 +1091,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, torch::Tensor const& s_group, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k) { + const auto dprops = at::cuda::getCurrentDeviceProperties(); + if (dprops->major < 8) { + TORCH_CHECK(false, __func__, "requires SM >= 8.0. Current device is SM", + dprops->major, ".", dprops->minor); + } + // Verify M TORCH_CHECK(size_m == a.size(0), "Shape mismatch: a.size(0) = " + str(a.size(0)) + From 8d14f0eec2fade8194c7a4767ac4ba96bfd2dd2e Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Wed, 5 Feb 2025 13:27:29 -0800 Subject: [PATCH 053/115] SAM2: more export, small perf improvements (#1673) --- .../sam2_amg_server/compile_export_utils.py | 219 +++++++++++++++--- examples/sam2_amg_server/generate_data.py | 54 ++++- .../sam2_amg_server/reproduce_experiments.py | 2 +- examples/sam2_amg_server/result.csv | 140 +++++------ .../_models/sam2/automatic_mask_generator.py | 10 +- .../sam2/modeling/sam/prompt_encoder.py | 6 + torchao/_models/sam2/sam2_image_predictor.py | 17 +- torchao/_models/sam2/utils/transforms.py | 9 +- 8 files changed, 326 insertions(+), 131 deletions(-) diff --git a/examples/sam2_amg_server/compile_export_utils.py b/examples/sam2_amg_server/compile_export_utils.py index a8f34b0943..5903f4905e 100644 --- a/examples/sam2_amg_server/compile_export_utils.py +++ b/examples/sam2_amg_server/compile_export_utils.py @@ -48,7 +48,6 @@ def forward( boxes: Optional[torch.Tensor] = None, mask_input: Optional[torch.Tensor] = None, multimask_output: bool = True, - img_idx: int = -1, ): assert high_res_feats[0].size() == (self.batch_size, 32, 256, 256) assert high_res_feats[1].size() == (self.batch_size, 64, 128, 128) @@ -73,7 +72,6 @@ def forward( assert boxes is None assert mask_input is None assert multimask_output - assert img_idx == -1 if self.predictor is None: assert self.aoti_compiled_model is not None return self.aoti_compiled_model( @@ -85,7 +83,6 @@ def forward( boxes=boxes, mask_input=mask_input, multimask_output=multimask_output, - img_idx=img_idx, ) return self.predictor._predict_masks( high_res_feats, @@ -96,7 +93,6 @@ def forward( boxes=boxes, mask_input=mask_input, multimask_output=multimask_output, - img_idx=img_idx, ) @@ -176,10 +172,137 @@ def export_model( overwrite=overwrite, ) - print(f"{task_type} cannot export _predict_masks") - return + if task_type in []: + example_input_args = () + example_input_kwargs = { + "points": ( + torch.randn( + points_per_batch, + 1, + 2, + dtype=torch.float32, + device=mask_generator.predictor.device, + ), + torch.ones( + points_per_batch, + 1, + dtype=torch.int32, + device=mask_generator.predictor.device, + ), + ), + "boxes": None, + "masks": None, + } + aot_compile( + model_directory, + "sam2_sam_prompt_encoder", + mask_generator.predictor.model.sam_prompt_encoder, + example_input_args, + sample_kwargs=example_input_kwargs, + overwrite=overwrite, + ) + + if task_type in []: + example_input_args = () + example_input_kwargs = { + "image_embeddings": torch.randn( + batch_size, + 256, + 64, + 64, + dtype=torch.float32, + device=mask_generator.predictor.device, + ), + "image_pe": torch.randn( + batch_size, + 256, + 64, + 64, + dtype=torch.float32, + device=mask_generator.predictor.device, + ), + "sparse_prompt_embeddings": torch.randn( + batch_size, + 2, + 256, + dtype=torch.float32, + device=mask_generator.predictor.device, + ), + "dense_prompt_embeddings": torch.randn( + batch_size, + 256, + 64, + 64, + dtype=torch.float32, + device=mask_generator.predictor.device, + ), + "multimask_output": True, + "repeat_image": False, + "high_res_features": [ + torch.randn( + batch_size, + 32, + 256, + 256, + dtype=mask_generator.predictor._image_dtype, + device=mask_generator.predictor.device, + ), + torch.randn( + batch_size, + 64, + 128, + 128, + dtype=mask_generator.predictor._image_dtype, + device=mask_generator.predictor.device, + ), + ], + } + aot_compile( + model_directory, + "sam2_sam_mask_decoder", + mask_generator.predictor.model.sam_mask_decoder, + example_input_args, + sample_kwargs=example_input_kwargs, + overwrite=overwrite, + ) + + if task_type in []: + example_input_args = ( + torch.randn( + points_per_batch, + 256, + 64, + 64, + dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype, + device=mask_generator.predictor.device, + ), + torch.randn( + points_per_batch, + 256, + 64, + 64, + dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype, + device=mask_generator.predictor.device, + ), + torch.randn( + points_per_batch, + 8, + 256, + dtype=mask_generator.predictor.model.sam_mask_decoder._src_dtype, + device=mask_generator.predictor.device, + ), + ) + example_input_kwargs = {} + aot_compile( + model_directory, + "sam2_sam_mask_decoder_transformer", + mask_generator.predictor.model.sam_mask_decoder.transformer, + example_input_args, + sample_kwargs=example_input_kwargs, + overwrite=overwrite, + ) - if task_type in ["sps"]: + if task_type in ["amg", "sps"]: example_input_high_res_feats = [ torch.randn( batch_size, @@ -239,7 +362,6 @@ def export_model( "boxes": None, "mask_input": None, "multimask_output": True, - "img_idx": -1, } sam2_image_predict_masks = SAM2ImagePredictor_predict_masks( @@ -301,30 +423,54 @@ def load_exported_model( pkg_m = LoadedModel(pkg) mask_generator.predictor.model.image_encoder = pkg_m - print(f"End load image encoder. Took {time.time() - t0}s") - return mask_generator - - if task_type in ["amg", "mps"]: + if task_type in ["mps"]: return mask_generator - path = Path(model_directory) / Path("sam2_image_predict_masks.pt2") - assert path.exists(), f"Expected {path} to exist" - print(f"Start load from {path}") - pkg = torch._inductor.aoti_load_package(str(path)) - if task_type == "amg": - assert points_per_batch > 1 - if task_type == "sps": - assert points_per_batch == 1 - if task_type == "mps": - assert points_per_batch is None - pkg_m = SAM2ImagePredictor_predict_masks( - None, - batch_size=batch_size, - points_per_batch=points_per_batch, - aoti_compiled_model=pkg, - furious=furious, - ) - mask_generator.predictor._predict_masks = pkg_m.forward + if task_type in []: + path = Path(model_directory) / Path("sam2_sam_prompt_encoder.pt2") + assert path.exists(), f"Expected {path} to exist" + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + pkg_m = LoadedModel(pkg) + mask_generator.predictor.model.sam_prompt_encoder.forward = pkg_m.forward + + if task_type in []: + path = Path(model_directory) / Path("sam2_sam_mask_decoder.pt2") + assert path.exists(), f"Expected {path} to exist" + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + pkg_m = LoadedModel(pkg) + mask_generator.predictor.model.sam_mask_decoder.forward = pkg_m.forward + + if task_type in []: + path = Path(model_directory) / Path("sam2_sam_mask_decoder_transformer.pt2") + assert path.exists(), f"Expected {path} to exist" + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + pkg_m = LoadedModel(pkg) + mask_generator.predictor.model.sam_mask_decoder.transformer.forward = ( + pkg_m.forward + ) + + if task_type in ["amg", "sps"]: + path = Path(model_directory) / Path("sam2_image_predict_masks.pt2") + assert path.exists(), f"Expected {path} to exist" + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + if task_type == "amg": + assert points_per_batch > 1 + if task_type == "sps": + assert points_per_batch == 1 + if task_type == "mps": + assert points_per_batch is None + pkg_m = SAM2ImagePredictor_predict_masks( + None, + batch_size=batch_size, + points_per_batch=points_per_batch, + aoti_compiled_model=pkg, + furious=furious, + ) + mask_generator.predictor._predict_masks = pkg_m.forward print(f"End load image encoder and predict masks. Took {time.time() - t0}s") @@ -352,12 +498,13 @@ def set_fast( dynamic=False, ) elif task_type == "amg": - mask_generator.predictor._predict_masks = torch.compile( - mask_generator.predictor._predict_masks, - mode="max-autotune", - fullgraph=True, - dynamic=False, - ) + if not loaded_exported_model: + mask_generator.predictor._predict_masks = torch.compile( + mask_generator.predictor._predict_masks, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) else: # TODO: This might need to be under "allow_recompiles" # mps encounters rapidly changing points per batch diff --git a/examples/sam2_amg_server/generate_data.py b/examples/sam2_amg_server/generate_data.py index 7c61a7f728..8632f0163a 100644 --- a/examples/sam2_amg_server/generate_data.py +++ b/examples/sam2_amg_server/generate_data.py @@ -21,6 +21,38 @@ from tqdm import tqdm +def profiler_runner(path, fn, *args, **kwargs): + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + ) as prof: + result = fn(*args, **kwargs) + prof.export_chrome_trace(path) + return result + + +def memory_runner(path, fn, *args, **kwargs): + print("Start memory recording") + torch.cuda.synchronize() + torch.cuda.memory._record_memory_history( + True, trace_alloc_max_entries=100000, trace_alloc_record_context=True + ) + result = fn(*args, **kwargs) + torch.cuda.synchronize() + snapshot = torch.cuda.memory._snapshot() + print("Finish memory recording") + import pickle + + with open(path, "wb") as f: + pickle.dump(snapshot, f) + # Use to convert pickle file into html + # python torch/cuda/_memory_viz.py trace_plot .pickle -o .html + return result + + def latencies_statistics(data): # Convert the list to a NumPy array data_array = np.array(data) @@ -330,16 +362,17 @@ def decode_img_bytes(img_bytes_tensors, gpu_preproc, baseline): for img_bytes_tensor in img_bytes_tensors: with record_function("decode image bytes"): if gpu_preproc: - # NOTE: We have to use numpy for the baseline - assert not baseline - from torchvision import io as tio - - image_tensor = tio.decode_jpeg( - img_bytes_tensor, device="cuda", mode=tio.ImageReadMode.RGB - ) - from torchvision.transforms.v2 import functional as F + image_tensor = file_bytes_to_image_tensor(img_bytes_tensor) + from torchvision.transforms import ToTensor, v2 - image_tensor = F.to_dtype(image_tensor, torch.float32, scale=True) + if not baseline: + image_tensor = torch.from_numpy(image_tensor) + image_tensor = image_tensor.permute((2, 0, 1)) + image_tensor = image_tensor.cuda() + with record_function("v2.ToDtype"): + image_tensor = v2.ToDtype(torch.float32, scale=True)( + image_tensor + ) else: image_tensor = file_bytes_to_image_tensor(img_bytes_tensor) from torchvision.transforms import ToTensor @@ -431,6 +464,7 @@ def main( quiet=False, gpu_preproc=False, batch_size=1, + seed=42, ): if batch_size <= 0: raise ValueError("Expected --batch_size to be at least 1 but got {batch_size}") @@ -502,6 +536,7 @@ def main( from torchao._models.sam2.utils.amg import ( mask_to_rle_pytorch_2 as mask_to_rle_pytorch, ) + torch.manual_seed(seed) device = "cuda" sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) if verbose: @@ -628,4 +663,5 @@ def main( main.__doc__ = main_docstring() if __name__ == "__main__": # profiler_runner("asdf.json.gz", fire.Fire, main) + # memory_runner("asdf.pickle", fire.Fire, main) fire.Fire(main) diff --git a/examples/sam2_amg_server/reproduce_experiments.py b/examples/sam2_amg_server/reproduce_experiments.py index 2684cd8111..c6799cd815 100644 --- a/examples/sam2_amg_server/reproduce_experiments.py +++ b/examples/sam2_amg_server/reproduce_experiments.py @@ -89,7 +89,7 @@ def run(task, output_path: Path, kwargs, baseline_folder=None, environ=None): stdout, stderr = run_script_with_args( [ "generate_data.py", - "~/checkpoints/sam2", + f"{str(Path.home())}/checkpoints/sam2", "large", task, image_paths, diff --git a/examples/sam2_amg_server/result.csv b/examples/sam2_amg_server/result.csv index aa43a8703e..0327159727 100644 --- a/examples/sam2_amg_server/result.csv +++ b/examples/sam2_amg_server/result.csv @@ -1,70 +1,70 @@ -p999,task,experiment_name,fourth,total_time,third,bytes_MiB,environ,allow-recompiles,p95,fail_count,torchvision_version,export-model,furious,baseline,max,bytes,fifth,argmax,meta-folder,batch-size,load-exported-model,torch_version,run_script_time,total_img_s,p99,second,total_ms_per_img,miou,num-images,fast,first,gpu-preproc,percentage,points-per-batch,median,mean,batch_size -2374ms,amg,baseline_amg,887ms,935.2057137489319s,947ms,4350,None,,1336ms,,0.22.0.dev20250109+cu124,,,None,2454ms,4561654784,717ms,222,,,,2.7.0.dev20250109+cu124,939.5637674331665,1.0692834584931363img/s,2148ms,1054ms,935.2057137489319ms,,,,1799ms,,4,64,872ms,928ms,1 -950ms,amg,amg_ao,716ms,727.5543773174286s,725ms,4010,None,,824ms,0.0,0.22.0.dev20250109+cu124,,,,1307ms,4205527040,713ms,0,,,,2.7.0.dev20250109+cu124,731.9675371646881,1.3744677115229624img/s,870ms,805ms,727.5543773174286ms,1.0,,,1307ms,,4,64,706ms,721ms,1 -1109ms,amg,amg_ao_ppb_1024_basic,574ms,643.2957496643066s,660ms,33774,None,,749ms,0.0,0.22.0.dev20250109+cu124,,,,1958ms,35415179776,575ms,109,,1,,2.7.0.dev20250109+cu124,647.9796307086945,1.5544949590011028img/s,806ms,615ms,643.2957496643066ms,0.9999994533658028,,,1108ms,,34,1024,622ms,637ms,1 -2781ms,amg,amg_ao_ppb_1024_fast_cold,410ms,877.4602742195129s,518ms,29349,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_inductor_cache_dir'},,546ms,,0.22.0.dev20250109+cu124,,,,427232ms,30775568896,394ms,0,,1,,2.7.0.dev20250109+cu124,886.4245429039001,1.1396527334408206img/s,607ms,2356ms,877.4602742195129ms,,,None,427232ms,,30,1024,423ms,870ms,1 -1392ms,amg,amg_ao_ppb_1024_fast,404ms,455.4250349998474s,440ms,29349,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_inductor_cache_dir'},,548ms,189.0,0.22.0.dev20250109+cu124,,,,8721ms,30775568896,486ms,0,,1,,2.7.0.dev20250109+cu124,460.94617104530334,2.1957510526410458img/s,607ms,1133ms,455.4250349998474ms,0.9936933217227973,,None,8721ms,,30,1024,425ms,448ms,1 -,amg,amg_ao_ppb_1024_save_export,,304.58769369125366s,,1593,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_inductor_cache_dir'},,,,0.22.0.dev20250109+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast,,,,1670930432,,,,1,,2.7.0.dev20250109+cu124,315.2948203086853,0.0img/s,,,,,0,,,,1,1024,,,1 -1061ms,amg,amg_ao_ppb_1024_load_export_cold,565ms,634.6407806873322s,631ms,32958,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_load_export_inductor_cache_dir'},,739ms,186.0,0.22.0.dev20250109+cu124,,,,1770ms,34559617024,680ms,10,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast,2.7.0.dev20250109+cu124,639.0105745792389,1.5756945195311503img/s,822ms,610ms,634.6407806873322ms,0.9945775083007625,,,1061ms,,33,1024,612ms,628ms,1 -1046ms,amg,amg_ao_ppb_1024_load_export,587ms,622.3058869838715s,603ms,32958,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_load_export_inductor_cache_dir'},,720ms,186.0,0.22.0.dev20250109+cu124,,,,1747ms,34559617024,564ms,10,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast,2.7.0.dev20250109+cu124,626.9090824127197,1.606926787799964img/s,759ms,611ms,622.3058869838715ms,0.9945775083007625,,,1045ms,,33,1024,599ms,616ms,1 -1704ms,amg,amg_ao_ppb_1024_load_export_gpu_preproc,603ms,612.9062254428864s,595ms,32982,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_load_export_inductor_cache_dir'},,699ms,772.0,0.22.0.dev20250109+cu124,,,,1730ms,34584782848,629ms,10,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast,2.7.0.dev20250109+cu124,617.6570754051208,1.631570962225746img/s,746ms,678ms,612.9062254428864ms,0.839199618648803,,,1704ms,None,33,1024,594ms,606ms,1 -1505ms,amg,amg_ao_ppb_1024_fast_export_cold,483ms,561.7602450847626s,456ms,28534,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_fast_export_inductor_cache_dir'},,567ms,186.0,0.22.0.dev20250109+cu124,,,,104358ms,29921054720,414ms,0,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast,2.7.0.dev20250109+cu124,567.9983367919922,1.7801188474081369img/s,634ms,1065ms,561.7602450847626ms,0.994521583840068,,None,104358ms,,29,1024,435ms,554ms,1 -1476ms,amg,amg_ao_ppb_1024_fast_export,389ms,446.44090843200684s,424ms,28534,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_fast_export_inductor_cache_dir'},,541ms,186.0,0.22.0.dev20250109+cu124,,,,3661ms,29921054720,380ms,0,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast,2.7.0.dev20250109+cu124,451.4739100933075,2.239938099562174img/s,635ms,742ms,446.44090843200684ms,0.994521583840068,,None,3661ms,,29,1024,421ms,439ms,1 -1432ms,amg,amg_ao_ppb_1024_fast_export_gpu_preproc,378ms,433.64031982421875s,411ms,28631,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_fast_export_inductor_cache_dir'},,513ms,772.0,0.22.0.dev20250109+cu124,,,,4632ms,30022200320,441ms,0,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast,2.7.0.dev20250109+cu124,439.1623215675354,2.306058625741633img/s,572ms,784ms,433.64031982421875ms,0.8391996832205015,,None,4632ms,None,29,1024,408ms,425ms,1 -2751ms,amg,amg_ao_ppb_1024_fast_furious_cold,163ms,841.2357618808746s,157ms,28335,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_furious_inductor_cache_dir'},,258ms,313.0,0.22.0.dev20250109+cu124,,None,,663906ms,29712144384,165ms,0,,1,,2.7.0.dev20250109+cu124,852.4052486419678,1.188727399990881img/s,307ms,2090ms,841.2357618808746ms,0.9721227795145918,,None,663906ms,,29,1024,158ms,833ms,1 -1106ms,amg,amg_ao_ppb_1024_fast_furious,167ms,182.73960876464844s,161ms,28335,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_furious_inductor_cache_dir'},,253ms,313.0,0.22.0.dev20250109+cu124,,None,,8233ms,29712144384,127ms,0,,1,,2.7.0.dev20250109+cu124,188.4141879081726,5.472267379580016img/s,312ms,1099ms,182.73960876464844ms,0.9721227795145918,,None,8233ms,,29,1024,158ms,176ms,1 -,amg,amg_ao_ppb_1024_save_export_furious,,426.2127423286438s,,954,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_furious_inductor_cache_dir'},,,,0.22.0.dev20250109+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,None,,,1000953344,,,,1,,2.7.0.dev20250109+cu124,434.3983988761902,0.0img/s,,,,,0,,,,0,1024,,,1 -1016ms,amg,amg_ao_ppb_1024_load_export_furious_cold,340ms,349.6220052242279s,332ms,27972,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_load_export_furious_inductor_cache_dir'},,427ms,203.0,0.22.0.dev20250109+cu124,,None,,2024ms,29330775040,302ms,468,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,2.7.0.dev20250109+cu124,353.6907768249512,2.860231864864044img/s,471ms,344ms,349.6220052242279ms,0.9895564557019261,,,1015ms,,28,1024,332ms,343ms,1 -1041ms,amg,amg_ao_ppb_1024_load_export_furious,301ms,360.9945259094238s,331ms,27972,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_load_export_furious_inductor_cache_dir'},,440ms,203.0,0.22.0.dev20250109+cu124,,None,,1978ms,29330775040,301ms,468,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,2.7.0.dev20250109+cu124,364.9874835014343,2.7701251077998545img/s,492ms,343ms,360.9945259094238ms,0.9895564557019261,,,1040ms,,28,1024,343ms,355ms,1 -1701ms,amg,amg_ao_ppb_1024_load_export_furious_gpu_preproc,299ms,329.88597416877747s,329ms,28039,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_load_export_furious_inductor_cache_dir'},,399ms,760.0,0.22.0.dev20250109+cu124,,None,,1966ms,29401540096,297ms,468,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,2.7.0.dev20250109+cu124,334.0973074436188,3.0313504613820785img/s,449ms,340ms,329.88597416877747ms,0.8335056624064843,,,1701ms,None,28,1024,308ms,324ms,1 -1170ms,amg,amg_ao_ppb_1024_fast_export_furious_cold,165ms,450.325879573822s,189ms,27949,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_fast_export_furious_inductor_cache_dir'},,269ms,303.0,0.22.0.dev20250109+cu124,,None,,261209ms,29307650560,164ms,0,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,2.7.0.dev20250109+cu124,456.4792420864105,2.220614104937466img/s,319ms,770ms,450.325879573822ms,0.9750078081486044,,None,261209ms,,28,1024,170ms,443ms,1 -935ms,amg,amg_ao_ppb_1024_fast_export_furious,166ms,177.67218565940857s,182ms,27949,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_fast_export_furious_inductor_cache_dir'},,253ms,303.0,0.22.0.dev20250109+cu124,,None,,3415ms,29307650560,128ms,0,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,2.7.0.dev20250109+cu124,183.61352038383484,5.628342986205873img/s,310ms,565ms,177.67218565940857ms,0.9750078081486044,,None,3415ms,,28,1024,157ms,171ms,1 -44632ms,amg,amg_ao_ppb_1024_fast_export_furious_recompiles,115ms,295.7107162475586s,132ms,13255,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_fast_export_furious_inductor_cache_dir'},None,197ms,305.0,0.22.0.dev20250109+cu124,,None,,63790ms,13898889728,168ms,0,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,2.7.0.dev20250109+cu124,301.4011402130127,3.3816833312284675img/s,237ms,454ms,295.7107162475586ms,0.9750330227313282,,None,63790ms,,13,1024,139ms,289ms,1 -885ms,amg,amg_ao_ppb_1024_fast_export_furious_gpu_preproc,125ms,156.32159233093262s,155ms,27973,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_fast_export_furious_inductor_cache_dir'},,224ms,773.0,0.22.0.dev20250109+cu124,,None,,4151ms,29332738048,120ms,0,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,2.7.0.dev20250109+cu124,162.26802515983582,6.3970689211187235img/s,275ms,396ms,156.32159233093262ms,0.8382131132391581,,None,4151ms,None,28,1024,132ms,150ms,1 -610ms,amg,amg_ao_ppb_1024_fast_export_furious_gpu_preproc_recompiles,114ms,138.77052688598633s,132ms,13227,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_fast_export_furious_inductor_cache_dir'},None,167ms,774.0,0.22.0.dev20250109+cu124,,None,,4890ms,13870295552,112ms,0,,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/amg_ao_fast_furious,2.7.0.dev20250109+cu124,144.96051049232483,7.206141119732136img/s,197ms,395ms,138.77052688598633ms,0.8381459507926375,,None,4890ms,None,13,1024,118ms,130ms,1 -306ms,sps,baseline_sps,100ms,132.67345762252808s,105ms,1337,None,,194ms,,0.22.0.dev20250109+cu124,,,None,571ms,1402492416,104ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,,,2.7.0.dev20250109+cu124,136.57290863990784,7.537302621939047img/s,276ms,222ms,132.67345762252808ms,,,,571ms,,1,1,113ms,127ms,1 -230ms,sps,sps_ao,98ms,126.97674512863159s,118ms,1339,None,,211ms,0.0,0.22.0.dev20250109+cu124,,,,545ms,1404942848,218ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,,,2.7.0.dev20250109+cu124,131.24220395088196,7.875457816996075img/s,222ms,115ms,126.97674512863158ms,1.0,,,545ms,,1,1,109ms,122ms,1 -232ms,sps,sps_ao_ppb_1_basic,100ms,136.22252011299133s,106ms,1339,None,,218ms,0.0,0.22.0.dev20250109+cu124,,,,638ms,1404942848,112ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,140.56182503700256,7.340930113248078img/s,225ms,117ms,136.22252011299133ms,1.0,,,638ms,,1,1,111ms,131ms,1 -3133ms,sps,sps_ao_ppb_1_fast_cold,91ms,524.464339017868s,97ms,1593,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_inductor_cache_dir'},,190ms,,0.22.0.dev20250109+cu124,,,,401201ms,1670930432,96ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,535.5261473655701,1.9067073308981088img/s,210ms,2734ms,524.464339017868ms,,,None,401201ms,,1,1,100ms,515ms,1 -779ms,sps,sps_ao_ppb_1_fast,212ms,132.37645173072815s,202ms,1302,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_inductor_cache_dir'},,206ms,0.0,0.22.0.dev20250109+cu124,,,,8140ms,1366200320,208ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,138.50028347969055,7.5542136605545img/s,213ms,772ms,132.37645173072815ms,0.9998687426447869,,None,8140ms,,1,1,101ms,126ms,1 -,sps,sps_ao_ppb_1_save_export,,272.5903356075287s,,1593,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_inductor_cache_dir'},,,,0.22.0.dev20250109+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast,,,,1670930432,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,283.19432258605957,0.0img/s,,,,,0,,,,1,1,,,1 -226ms,sps,sps_ao_ppb_1_load_export_cold,213ms,161.28311896324158s,211ms,5949,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_load_export_inductor_cache_dir'},,216ms,0.0,0.22.0.dev20250109+cu124,,,,707ms,6238665728,185ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast,2.7.0.dev20250109+cu124,165.69491052627563,6.2002769194208875img/s,221ms,225ms,161.28311896324158ms,0.999868677020073,,,707ms,,6,1,139ms,155ms,1 -245ms,sps,sps_ao_ppb_1_load_export,93ms,131.32559871673584s,98ms,5949,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_load_export_inductor_cache_dir'},,211ms,0.0,0.22.0.dev20250109+cu124,,,,597ms,6238665728,98ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast,2.7.0.dev20250109+cu124,136.12982988357544,7.614661648388603img/s,220ms,134ms,131.32559871673584ms,0.999868677020073,,,597ms,,6,1,104ms,125ms,1 -196ms,sps,sps_ao_ppb_1_load_export_gpu_preproc,159ms,117.73162794113159s,164ms,5971,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_load_export_inductor_cache_dir'},,162ms,0.0,0.22.0.dev20250109+cu124,,,,1361ms,6261886976,164ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast,2.7.0.dev20250109+cu124,122.47605919837952,8.493894270280727img/s,171ms,139ms,117.73162794113159ms,0.9861222158936289,,,1361ms,None,6,1,101ms,111ms,1 -228ms,sps,sps_ao_ppb_1_fast_export_cold,92ms,120.34239029884338s,96ms,5949,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_fast_export_inductor_cache_dir'},,203ms,0.0,0.22.0.dev20250109+cu124,,,,541ms,6238665728,97ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast,2.7.0.dev20250109+cu124,124.82643246650696,8.309623878308582img/s,215ms,155ms,120.34239029884338ms,0.999868677020073,,None,541ms,,6,1,101ms,114ms,1 -229ms,sps,sps_ao_ppb_1_fast_export,135ms,120.78508996963501s,96ms,5949,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_fast_export_inductor_cache_dir'},,203ms,0.0,0.22.0.dev20250109+cu124,,,,570ms,6238665728,116ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast,2.7.0.dev20250109+cu124,124.93209862709045,8.279167571522253img/s,212ms,106ms,120.78508996963501ms,0.999868677020073,,None,570ms,,6,1,102ms,115ms,1 -184ms,sps,sps_ao_ppb_1_fast_export_gpu_preproc,92ms,120.33534979820251s,94ms,5971,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_fast_export_inductor_cache_dir'},,164ms,0.0,0.22.0.dev20250109+cu124,,,,1240ms,6261886976,93ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast,2.7.0.dev20250109+cu124,124.94753289222717,8.310110052257789img/s,169ms,108ms,120.33534979820251ms,0.9861222158936289,,None,1240ms,None,6,1,97ms,114ms,1 -2368ms,sps,sps_ao_ppb_1_fast_furious_cold,19ms,581.2481288909912s,24ms,954,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_furious_inductor_cache_dir'},,70ms,0.0,0.22.0.dev20250109+cu124,,None,,532242ms,1000953344,35ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,592.1693325042725,1.7204356458023844img/s,74ms,1838ms,581.2481288909912ms,0.9996674702763557,,None,532242ms,,0,1,35ms,574ms,1 -614ms,sps,sps_ao_ppb_1_fast_furious,53ms,45.71470355987549s,25ms,861,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_furious_inductor_cache_dir'},,60ms,0.0,0.22.0.dev20250109+cu124,,None,,8026ms,903450624,23ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,51.57617497444153,21.874800056184018img/s,68ms,606ms,45.71470355987549ms,0.9996674702763557,,None,8026ms,,0,1,29ms,40ms,1 -,sps,sps_ao_ppb_1_save_export_furious,,364.1186008453369s,,954,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_furious_inductor_cache_dir'},,,,0.22.0.dev20250109+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,None,,,1000953344,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,372.80925726890564,0.0img/s,,,,,0,,,,0,1,,,1 -78ms,sps,sps_ao_ppb_1_load_export_furious_cold,50ms,53.28082203865051s,43ms,1790,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_load_export_furious_inductor_cache_dir'},,69ms,0.0,0.22.0.dev20250109+cu124,,None,,939ms,1877512192,24ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,2.7.0.dev20250109+cu124,57.669695138931274,18.76847919640933img/s,74ms,73ms,53.28082203865051ms,0.9998199329972267,,,939ms,,1,1,48ms,47ms,1 -80ms,sps,sps_ao_ppb_1_load_export_furious,21ms,50.997873306274414s,24ms,1790,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_load_export_furious_inductor_cache_dir'},,70ms,0.0,0.22.0.dev20250109+cu124,,None,,861ms,1877512192,24ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,2.7.0.dev20250109+cu124,55.45322823524475,19.60866081599852img/s,74ms,33ms,50.997873306274414ms,0.9998199329972267,,,861ms,,1,1,42ms,45ms,1 -29ms,sps,sps_ao_ppb_1_load_export_furious_gpu_preproc,17ms,24.790576696395874s,18ms,1814,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_load_export_furious_inductor_cache_dir'},,19ms,0.0,0.22.0.dev20250109+cu124,,None,,1612ms,1902484480,18ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,2.7.0.dev20250109+cu124,29.53805947303772,40.33790791746216img/s,19ms,27ms,24.790576696395874ms,0.9860970453268383,,,1612ms,None,1,1,17ms,19ms,1 -82ms,sps,sps_ao_ppb_1_fast_export_furious_cold,20ms,39.87857627868652s,36ms,1790,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_fast_export_furious_inductor_cache_dir'},,61ms,0.0,0.22.0.dev20250109+cu124,,None,,866ms,1877512192,25ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,2.7.0.dev20250109+cu124,44.19964957237244,25.076120897888206img/s,71ms,35ms,39.87857627868652ms,0.9998199329972267,,None,866ms,,1,1,31ms,34ms,1 -75ms,sps,sps_ao_ppb_1_fast_export_furious,20ms,40.75656461715698s,24ms,1790,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_fast_export_furious_inductor_cache_dir'},,64ms,0.0,0.22.0.dev20250109+cu124,,None,,865ms,1877512192,26ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,2.7.0.dev20250109+cu124,45.36444664001465,24.53592468829028img/s,70ms,34ms,40.75656461715698ms,0.9998199329972267,,None,865ms,,1,1,31ms,35ms,1 -93ms,sps,sps_ao_ppb_1_fast_export_furious_recompiles,21ms,49.636521339416504s,25ms,1790,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_fast_export_furious_inductor_cache_dir'},None,66ms,0.0,0.22.0.dev20250109+cu124,,None,,9723ms,1877512192,25ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,2.7.0.dev20250109+cu124,55.89960026741028,20.146456137849796img/s,73ms,37ms,49.636521339416504ms,0.24249802377738716,,None,9723ms,,1,1,31ms,44ms,1 -29ms,sps,sps_ao_ppb_1_fast_export_furious_gpu_preproc,17ms,24.562424421310425s,19ms,1814,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_fast_export_furious_inductor_cache_dir'},,19ms,0.0,0.22.0.dev20250109+cu124,,None,,1566ms,1902484480,18ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,2.7.0.dev20250109+cu124,29.499178171157837,40.71259346583057img/s,19ms,27ms,24.562424421310425ms,0.9860970453268383,,None,1566ms,None,1,1,17ms,19ms,1 -32ms,sps,sps_ao_ppb_1_fast_export_furious_gpu_preproc_recompiles,17ms,26.11998414993286s,19ms,1814,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/sps_fast_export_furious_inductor_cache_dir'},None,19ms,0.0,0.22.0.dev20250109+cu124,,None,,3477ms,1902484480,18ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/sps_ao_fast_furious,2.7.0.dev20250109+cu124,32.0809326171875,38.284862435591116img/s,20ms,29ms,26.11998414993286ms,0.18694353939804045,,None,3477ms,None,1,1,17ms,21ms,1 -1614ms,mps,baseline_mps,217ms,339.7126615047455s,368ms,1337,None,,738ms,,0.22.0.dev20250109+cu124,,,None,1837ms,1402492416,510ms,126,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,,,2.7.0.dev20250109+cu124,344.3770024776459,2.943664200122935img/s,1304ms,490ms,339.7126615047455ms,,,,579ms,,1,,263ms,332ms,1 -385ms,mps,mps_ao,104ms,139.90302205085754s,118ms,8022,None,,215ms,0.0,0.22.0.dev20250109+cu124,,,,600ms,8411699712,150ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,,,2.7.0.dev20250109+cu124,144.1774024963379,7.147808427158064img/s,237ms,132ms,139.90302205085754ms,0.999999164044857,,,600ms,,8,,121ms,133ms,1 -295ms,mps,mps_ao_ppb_None_basic,216ms,180.09048891067505s,231ms,8022,None,,236ms,0.0,0.22.0.dev20250109+cu124,,,,622ms,8411699712,246ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,184.8732569217682,5.55276409125637img/s,263ms,236ms,180.09048891067505ms,0.999999164044857,,,622ms,,8,,162ms,171ms,1 -43126ms,mps,mps_ao_ppb_None_fast_cold,93ms,531.2832531929016s,104ms,8021,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_inductor_cache_dir'},,208ms,,0.22.0.dev20250109+cu124,,,,331945ms,8411176448,110ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,543.5350062847137,1.8822351240890964img/s,224ms,1009ms,531.2832531929016ms,,,None,331945ms,,8,,107ms,524ms,1 -1451ms,mps,mps_ao_ppb_None_fast,95ms,177.8515875339508s,109ms,8021,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_inductor_cache_dir'},,226ms,0.0,0.22.0.dev20250109+cu124,,,,8897ms,8411176448,147ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,183.4075665473938,5.622665582386809img/s,248ms,581ms,177.8515875339508ms,0.9983835342526436,,None,8897ms,,8,,146ms,170ms,1 -,mps,mps_ao_ppb_None_save_export,,262.2255263328552s,,1593,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_inductor_cache_dir'},,,,0.22.0.dev20250109+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast,,,,1670930432,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,270.12541913986206,0.0img/s,,,,,0,,,,1,,,,1 -333ms,mps,mps_ao_ppb_None_load_export_cold,97ms,138.29926824569702s,111ms,7206,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_load_export_inductor_cache_dir'},,220ms,0.0,0.22.0.dev20250109+cu124,,,,649ms,7556661248,120ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast,2.7.0.dev20250109+cu124,142.37936091423035,7.230696247961626img/s,234ms,125ms,138.29926824569702ms,0.9983786268234253,,,649ms,,7,,114ms,131ms,1 -320ms,mps,mps_ao_ppb_None_load_export,96ms,132.98988270759583s,109ms,7206,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_load_export_inductor_cache_dir'},,212ms,0.0,0.22.0.dev20250109+cu124,,,,543ms,7556661248,118ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast,2.7.0.dev20250109+cu124,137.46344566345215,7.519368989885455img/s,235ms,185ms,132.98988270759583ms,0.9983786268234253,,,543ms,,7,,112ms,125ms,1 -369ms,mps,mps_ao_ppb_None_load_export_gpu_preproc,95ms,153.9310953617096s,179ms,7230,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_load_export_inductor_cache_dir'},,184ms,0.0,0.22.0.dev20250109+cu124,,,,1217ms,7581827072,127ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast,2.7.0.dev20250109+cu124,159.28356790542603,6.496413201310528img/s,202ms,139ms,153.9310953617096ms,0.9224205894982442,,,1217ms,None,7,,153ms,145ms,1 -37104ms,mps,mps_ao_ppb_None_fast_export_cold,96ms,236.0241584777832s,107ms,7206,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_fast_export_inductor_cache_dir'},,206ms,0.0,0.22.0.dev20250109+cu124,,,,39205ms,7556661248,113ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast,2.7.0.dev20250109+cu124,244.1103572845459,4.23685442392597img/s,229ms,119ms,236.0241584777832ms,0.9983784531950951,,None,39205ms,,7,,109ms,227ms,1 -1280ms,mps,mps_ao_ppb_None_fast_export,103ms,132.519935131073s,176ms,7206,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_fast_export_inductor_cache_dir'},,203ms,0.0,0.22.0.dev20250109+cu124,,,,3634ms,7556661248,155ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast,2.7.0.dev20250109+cu124,137.68328261375427,7.54603448161153img/s,223ms,223ms,132.519935131073ms,0.9983784534335136,,None,3634ms,,7,,109ms,125ms,1 -1267ms,mps,mps_ao_ppb_None_fast_export_gpu_preproc,157ms,147.0070924758911s,181ms,7230,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_fast_export_inductor_cache_dir'},,175ms,0.0,0.22.0.dev20250109+cu124,,,,3928ms,7581827072,118ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast,2.7.0.dev20250109+cu124,152.5612542629242,6.80239288566297img/s,195ms,185ms,147.0070924758911ms,0.9224205495780334,,None,3928ms,None,7,,131ms,139ms,1 -44108ms,mps,mps_ao_ppb_None_fast_furious_cold,22ms,604.3798043727875s,30ms,4222,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_furious_inductor_cache_dir'},,69ms,0.0,0.22.0.dev20250109+cu124,,None,,488223ms,4427842560,69ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,616.8908636569977,1.654588708565103img/s,80ms,1530ms,604.3798043727875ms,0.9972913320064545,,None,488223ms,,4,,33ms,597ms,1 -1341ms,mps,mps_ao_ppb_None_fast_furious,59ms,78.28538370132446s,66ms,4222,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_furious_inductor_cache_dir'},,79ms,0.0,0.22.0.dev20250109+cu124,,None,,9623ms,4427842560,73ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,84.57566738128662,12.773776568755345img/s,89ms,551ms,78.28538370132446ms,0.9972910861372948,,None,9623ms,,4,,61ms,70ms,1 -,mps,mps_ao_ppb_None_save_export_furious,,349.34193754196167s,,954,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_furious_inductor_cache_dir'},,,,0.22.0.dev20250109+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,None,,,1000953344,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,,2.7.0.dev20250109+cu124,360.5604326725006,0.0img/s,,,,,0,,,,0,,,,1 -309ms,mps,mps_ao_ppb_None_load_export_furious_cold,34ms,56.33559775352478s,41ms,3813,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_load_export_furious_inductor_cache_dir'},,80ms,0.0,0.22.0.dev20250109+cu124,,None,,765ms,3998387200,43ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,2.7.0.dev20250109+cu124,60.93665313720703,17.75076576581514img/s,88ms,54ms,56.33559775352478ms,0.9961582001447677,,,765ms,,3,,44ms,49ms,1 -353ms,mps,mps_ao_ppb_None_load_export_furious,33ms,56.61087965965271s,40ms,3813,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_load_export_furious_inductor_cache_dir'},,80ms,0.0,0.22.0.dev20250109+cu124,,None,,845ms,3998387200,40ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,2.7.0.dev20250109+cu124,61.454379081726074,17.664449060181493img/s,88ms,85ms,56.61087965965271ms,0.9961582001447677,,,845ms,,3,,44ms,49ms,1 -322ms,mps,mps_ao_ppb_None_load_export_furious_gpu_preproc,29ms,40.086507081985474s,33ms,3837,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_load_export_furious_inductor_cache_dir'},,39ms,0.0,0.22.0.dev20250109+cu124,,None,,1539ms,4023553024,33ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,2.7.0.dev20250109+cu124,44.91008281707764,24.94604975072501img/s,49ms,49ms,40.086507081985474ms,0.9239367794789141,,,1539ms,None,3,,30ms,33ms,1 -32689ms,mps,mps_ao_ppb_None_fast_export_furious_cold,60ms,157.29275488853455s,67ms,3813,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_fast_export_furious_inductor_cache_dir'},,74ms,0.0,0.22.0.dev20250109+cu124,,None,,45808ms,3998387200,55ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,2.7.0.dev20250109+cu124,165.38462448120117,6.35757190919982img/s,89ms,78ms,157.29275488853455ms,0.9969035378098487,,None,45808ms,,3,,38ms,147ms,1 -1401ms,mps,mps_ao_ppb_None_fast_export_furious,60ms,50.659629821777344s,68ms,3813,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_fast_export_furious_inductor_cache_dir'},,70ms,0.0,0.22.0.dev20250109+cu124,,None,,3938ms,3998387200,70ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,2.7.0.dev20250109+cu124,56.82898807525635,19.73958363924176img/s,80ms,77ms,50.659629821777344ms,0.9969037767052651,,None,3938ms,,3,,33ms,43ms,1 -8305ms,mps,mps_ao_ppb_None_fast_export_furious_recompiles,21ms,65.21127843856812s,28ms,3813,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_fast_export_furious_inductor_cache_dir'},None,63ms,0.0,0.22.0.dev20250109+cu124,,None,,13909ms,3998387200,54ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,2.7.0.dev20250109+cu124,71.5342059135437,15.334770670721383img/s,77ms,38ms,65.21127843856812ms,0.9963943874835968,,None,13909ms,,3,,33ms,58ms,1 -1311ms,mps,mps_ao_ppb_None_fast_export_furious_gpu_preproc,19ms,33.9236855506897s,24ms,3837,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_fast_export_furious_inductor_cache_dir'},,30ms,0.0,0.22.0.dev20250109+cu124,,None,,4556ms,4023553024,26ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,2.7.0.dev20250109+cu124,40.050333738327026,29.47792917446345img/s,38ms,31ms,33.9236855506897ms,0.9237591220784234,,None,4556ms,None,3,,20ms,27ms,1 -1649ms,mps,mps_ao_ppb_None_fast_export_furious_gpu_preproc_recompiles,18ms,34.80714464187622s,23ms,3837,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/mps_fast_export_furious_inductor_cache_dir'},None,28ms,0.0,0.22.0.dev20250109+cu124,,None,,5661ms,4023553024,25ms,0,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/amg_baseline_annotations,1,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_11/exported_models/mps_ao_fast_furious,2.7.0.dev20250109+cu124,41.254807472229004,28.729733802895954img/s,34ms,31ms,34.80714464187622ms,0.9227598560500192,,None,5661ms,None,3,,20ms,28ms,1 +furious,fast,points-per-batch,bytes,argmax,p95,p999,p99,miou,fourth,total_time,torch_version,total_img_s,batch-size,second,experiment_name,run_script_time,mean,batch_size,percentage,third,task,num-images,fifth,environ,fail_count,allow-recompiles,max,load-exported-model,torchvision_version,median,total_ms_per_img,gpu-preproc,meta-folder,bytes_MiB,first,baseline,export-model +,,64,4561654784,468,1323ms,2363ms,2086ms,,892ms,927.4758312702179s,2.7.0.dev20250201+cu124,1.0781952114379705img/s,,1046ms,baseline_amg,931.3759133815765,921ms,1,4,955ms,amg,,724ms,None,,,2466ms,,0.22.0.dev20250201+cu124,869ms,927.4758312702179ms,,,4350,1733ms,None, +,,64,4205527040,0,815ms,904ms,857ms,1.0,660ms,718.6690595149994s,2.7.0.dev20250201+cu124,1.3914610442181266img/s,,748ms,amg_ao,723.3117945194244,713ms,1,4,673ms,amg,,760ms,None,0.0,,1263ms,,0.22.0.dev20250201+cu124,697ms,718.6690595149994ms,,,4010,1263ms,, +,,1024,35427762688,109,745ms,1006ms,791ms,0.9999994533658028,577ms,631.6344785690308s,2.7.0.dev20250201+cu124,1.5831941319376708img/s,1,619ms,amg_ao_ppb_1024_basic,635.8103907108307,626ms,1,34,594ms,amg,,609ms,None,0.0,,1947ms,,0.22.0.dev20250201+cu124,610ms,631.6344785690308ms,,,33786,1005ms,, +,None,1024,30775568896,0,576ms,3526ms,644ms,,501ms,849.2408077716827s,2.7.0.dev20250201+cu124,1.1775223126923131img/s,1,3157ms,amg_ao_ppb_1024_fast_cold,861.5647690296173,841ms,1,30,421ms,amg,,501ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_inductor_cache_dir'},,,372124ms,,0.22.0.dev20250201+cu124,466ms,849.2408077716827ms,,,29349,372124ms,, +,None,1024,30775568896,0,541ms,1512ms,617ms,0.9937346105006776,386ms,452.082448720932s,2.7.0.dev20250201+cu124,2.2119858951155487img/s,1,1000ms,amg_ao_ppb_1024_fast,458.1768579483032,446ms,1,30,448ms,amg,,392ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_inductor_cache_dir'},191.0,,8411ms,,0.22.0.dev20250201+cu124,422ms,452.082448720932ms,,,29349,8411ms,, +,,1024,18221665280,,,,,,,356.0369083881378s,2.7.0.dev20250201+cu124,0.0img/s,1,,amg_ao_ppb_1024_save_export,367.34787678718567,,1,17,,amg,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,,17377,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast +,,1024,49836364288,837,559ms,1592ms,639ms,0.993709121615135,397ms,460.2203013896942s,2.7.0.dev20250201+cu124,2.1728724199701137img/s,1,493ms,amg_ao_ppb_1024_load_export_cold,464.4886541366577,453ms,1,48,443ms,amg,,510ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_inductor_cache_dir'},188.0,,1760ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,436ms,460.2203013896942ms,,,47527,961ms,, +,,1024,49836364288,837,592ms,1691ms,649ms,0.993709121615135,445ms,478.4169816970825s,2.7.0.dev20250201+cu124,2.09022680685939img/s,1,431ms,amg_ao_ppb_1024_load_export,483.0541400909424,472ms,1,48,429ms,amg,,508ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_inductor_cache_dir'},188.0,,1737ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,462ms,478.4169816970825ms,,,47527,763ms,, +,,1024,49861530112,837,565ms,1670ms,622ms,0.9937652501226203,398ms,465.69065976142883s,2.7.0.dev20250201+cu124,2.1473482000096276img/s,1,435ms,amg_ao_ppb_1024_load_export_gpu_preproc,469.74300265312195,460ms,1,48,427ms,amg,,397ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_inductor_cache_dir'},185.0,,1735ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,452ms,465.69065976142883ms,None,,47551,776ms,, +,None,1024,49836364288,837,546ms,1611ms,608ms,0.993709121615135,415ms,454.15750002861023s,2.7.0.dev20250201+cu124,2.201879303847242img/s,1,438ms,amg_ao_ppb_1024_fast_export_cold,458.17887783050537,448ms,1,48,545ms,amg,,421ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_inductor_cache_dir'},188.0,,1730ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,430ms,454.15750002861023ms,,,47527,943ms,, +,None,1024,49836364288,837,577ms,1702ms,643ms,0.993709121615135,402ms,473.2662968635559s,2.7.0.dev20250201+cu124,2.112975309307316img/s,1,432ms,amg_ao_ppb_1024_fast_export,477.25709891319275,467ms,1,48,427ms,amg,,486ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_inductor_cache_dir'},188.0,,1742ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,451ms,473.2662968635559ms,,,47527,754ms,, +,None,1024,49861530112,837,543ms,1597ms,596ms,0.9937652501226203,396ms,450.6334979534149s,2.7.0.dev20250201+cu124,2.219098235132482img/s,1,433ms,amg_ao_ppb_1024_fast_export_gpu_preproc,454.61152243614197,445ms,1,48,426ms,amg,,395ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_inductor_cache_dir'},185.0,,1766ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,430ms,450.6334979534149ms,None,,47551,764ms,, +None,None,1024,29712131072,0,275ms,2880ms,333ms,0.9736336072679046,169ms,994.9303135871887s,2.7.0.dev20250201+cu124,1.0050955190967423img/s,1,2081ms,amg_ao_ppb_1024_fast_furious_cold,1006.4958641529083,987ms,1,29,192ms,amg,,143ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_furious_inductor_cache_dir'},305.0,,800771ms,,0.22.0.dev20250201+cu124,174ms,994.9303135871887ms,,,28335,800771ms,, +None,None,1024,29712131072,0,274ms,933ms,334ms,0.9736336072679046,163ms,192.62348794937134s,2.7.0.dev20250201+cu124,5.191474885258216img/s,1,699ms,amg_ao_ppb_1024_fast_furious,198.63731622695923,186ms,1,29,179ms,amg,,130ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_furious_inductor_cache_dir'},305.0,,10094ms,,0.22.0.dev20250201+cu124,165ms,192.62348794937134ms,,,28335,10094ms,, +None,,1024,9179703808,,,,,,,519.6249597072601s,2.7.0.dev20250201+cu124,0.0img/s,1,,amg_ao_ppb_1024_save_export_furious,529.3503592014313,,1,8,,amg,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_furious_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,,8754,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious +None,,1024,29307644416,468,259ms,906ms,309ms,0.971583874842335,166ms,178.88770842552185s,2.7.0.dev20250201+cu124,5.590099000101732img/s,1,202ms,amg_ao_ppb_1024_load_export_furious_cold,183.20707321166992,169ms,1,28,198ms,amg,,169ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_furious_inductor_cache_dir'},308.0,,1468ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,158ms,178.88770842552185ms,,,27949,906ms,, +None,,1024,29307644416,468,258ms,716ms,299ms,0.971583874842335,167ms,173.60630631446838s,2.7.0.dev20250201+cu124,5.760159416033033img/s,1,164ms,amg_ao_ppb_1024_load_export_furious,177.37090826034546,168ms,1,28,156ms,amg,,125ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_furious_inductor_cache_dir'},308.0,,1468ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,157ms,173.60630631446838ms,,,27949,716ms,, +None,,1024,29308632576,468,232ms,679ms,282ms,0.9707489542138409,126ms,156.5510959625244s,2.7.0.dev20250201+cu124,6.387690829321198img/s,1,160ms,amg_ao_ppb_1024_load_export_furious_gpu_preproc,160.46401953697205,151ms,1,28,155ms,amg,,126ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_furious_inductor_cache_dir'},290.0,,1467ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,136ms,156.5510959625244ms,None,,27950,678ms,, +None,None,1024,29307644416,468,268ms,750ms,320ms,0.971583874842335,159ms,182.61804270744324s,2.7.0.dev20250201+cu124,5.4759101848551435img/s,1,162ms,amg_ao_ppb_1024_fast_export_furious_cold,187.25734424591064,177ms,1,28,158ms,amg,,149ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},308.0,,1466ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,165ms,182.61804270744324ms,,,27949,750ms,, +None,None,1024,29307644416,468,259ms,700ms,308ms,0.971583874842335,134ms,178.3385353088379s,2.7.0.dev20250201+cu124,5.607313070437913img/s,1,160ms,amg_ao_ppb_1024_fast_export_furious,182.3735547065735,173ms,1,28,157ms,amg,,162ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},308.0,,1507ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,163ms,178.3385353088379ms,,,27949,700ms,, +None,None,1024,16525926912,0,201ms,36421ms,227ms,0.9716291864482343,141ms,245.76354837417603s,2.7.0.dev20250201+cu124,4.068951667630937img/s,1,137ms,amg_ao_ppb_1024_fast_export_furious_recompiles,251.90375113487244,240ms,1,16,131ms,amg,,128ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},311.0,None,49208ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,140ms,245.76354837417603ms,,,15760,49208ms,, +None,None,1024,29308632576,468,233ms,774ms,283ms,0.9707489542138409,127ms,157.9279761314392s,2.7.0.dev20250201+cu124,6.3320003491194425img/s,1,163ms,amg_ao_ppb_1024_fast_export_furious_gpu_preproc,162.7095422744751,152ms,1,28,157ms,amg,,129ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},290.0,,1464ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,137ms,157.9279761314392ms,None,,27950,773ms,, +None,None,1024,16551092736,0,174ms,308ms,203ms,0.9708677416053486,115ms,137.26364755630493s,2.7.0.dev20250201+cu124,7.28525008480344img/s,1,135ms,amg_ao_ppb_1024_fast_export_furious_gpu_preproc_recompiles,142.44125938415527,130ms,1,16,135ms,amg,,116ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},293.0,None,2189ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,121ms,137.26364755630493ms,None,,15784,2189ms,, +,,1,1402492416,0,214ms,316ms,281ms,,100ms,136.17227387428284s,2.7.0.dev20250201+cu124,7.343638844741783img/s,,118ms,baseline_sps,140.2417643070221,131ms,1,1,105ms,sps,,227ms,None,,,532ms,,0.22.0.dev20250201+cu124,115ms,136.17227387428284ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1337,532ms,None, +,,1,1404942848,0,205ms,229ms,219ms,1.0,105ms,127.24607348442078s,2.7.0.dev20250201+cu124,7.858788665274091img/s,,105ms,sps_ao,131.5206482410431,122ms,1,1,102ms,sps,,225ms,None,0.0,,579ms,,0.22.0.dev20250201+cu124,110ms,127.24607348442076ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1339,579ms,, +,,1,1404989952,0,203ms,256ms,218ms,1.0,106ms,124.8940806388855s,2.7.0.dev20250201+cu124,8.006784588065194img/s,1,104ms,sps_ao_ppb_1_basic,128.7957148551941,120ms,1,1,102ms,sps,,217ms,None,0.0,,583ms,,0.22.0.dev20250201+cu124,109ms,124.8940806388855ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1339,583ms,, +,None,1,1408784896,0,216ms,3260ms,223ms,,201ms,488.7042841911316s,2.7.0.dev20250201+cu124,2.046227201906217img/s,1,2959ms,sps_ao_ppb_1_fast_cold,496.82423877716064,483ms,1,1,212ms,sps,,209ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_inductor_cache_dir'},,,304090ms,,0.22.0.dev20250201+cu124,203ms,488.7042841911316ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1343,304090ms,, +,None,1,1366200320,0,217ms,775ms,222ms,0.9998691322207451,122ms,196.3028929233551s,2.7.0.dev20250201+cu124,5.0941684307752img/s,1,768ms,sps_ao_ppb_1_fast,202.54180693626404,189ms,1,1,195ms,sps,,208ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_inductor_cache_dir'},0.0,,8209ms,,0.22.0.dev20250201+cu124,205ms,196.3028929233551ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1302,8209ms,, +,,1,1390578176,,,,,,,307.4514627456665s,2.7.0.dev20250201+cu124,0.0img/s,1,,sps_ao_ppb_1_save_export,316.7780604362488,,1,1,,sps,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1326,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast +,,1,6238665728,0,215ms,233ms,221ms,0.9998687437176704,202ms,160.5826907157898s,2.7.0.dev20250201+cu124,6.227321235822784img/s,1,221ms,sps_ao_ppb_1_load_export_cold,165.16510462760925,153ms,1,6,198ms,sps,,214ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_inductor_cache_dir'},0.0,,576ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,138ms,160.5826907157898ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5949,576ms,, +,,1,6238665728,0,213ms,294ms,220ms,0.9998687437176704,210ms,130.84592247009277s,2.7.0.dev20250201+cu124,7.642576712534304img/s,1,108ms,sps_ao_ppb_1_load_export,135.52789616584778,125ms,1,6,144ms,sps,,140ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_inductor_cache_dir'},0.0,,434ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,104ms,130.84592247009277ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5949,434ms,, +,,1,6261886976,0,165ms,180ms,175ms,0.999868236720562,100ms,118.1360731124878s,2.7.0.dev20250201+cu124,8.46481496847971img/s,1,103ms,sps_ao_ppb_1_load_export_gpu_preproc,122.45444965362549,112ms,1,6,103ms,sps,,98ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_inductor_cache_dir'},0.0,,488ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,103ms,118.1360731124878ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5971,488ms,, +,None,1,6238665728,0,206ms,226ms,216ms,0.9998687437176704,92ms,124.29203748703003s,2.7.0.dev20250201+cu124,8.045567682518286img/s,1,121ms,sps_ao_ppb_1_fast_export_cold,128.70573449134827,118ms,1,6,135ms,sps,,96ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_inductor_cache_dir'},0.0,,430ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,104ms,124.29203748703003ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5949,430ms,, +,None,1,6238665728,0,200ms,226ms,216ms,0.9998687437176704,99ms,121.70427465438843s,2.7.0.dev20250201+cu124,8.216638263855277img/s,1,99ms,sps_ao_ppb_1_fast_export,126.40637016296387,115ms,1,6,96ms,sps,,105ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_inductor_cache_dir'},0.0,,474ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,103ms,121.70427465438843ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5949,474ms,, +,None,1,6261886976,0,168ms,189ms,178ms,0.999868236720562,93ms,122.82635688781738s,2.7.0.dev20250201+cu124,8.141575027852884img/s,1,107ms,sps_ao_ppb_1_fast_export_gpu_preproc,127.55544590950012,117ms,1,6,98ms,sps,,172ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_inductor_cache_dir'},0.0,,481ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,104ms,122.82635688781738ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5971,481ms,, +None,None,1,903450624,0,66ms,2448ms,71ms,0.9996802344322204,18ms,598.2366213798523s,2.7.0.dev20250201+cu124,1.6715793788977134img/s,1,1896ms,sps_ao_ppb_1_fast_furious_cold,606.6854190826416,590ms,1,0,24ms,sps,,30ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_furious_inductor_cache_dir'},0.0,,553957ms,,0.22.0.dev20250201+cu124,30ms,598.2366213798523ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,861,553957ms,, +None,None,1,903450624,0,60ms,922ms,68ms,0.9996802344322204,19ms,46.42959976196289s,2.7.0.dev20250201+cu124,21.537984499690705img/s,1,914ms,sps_ao_ppb_1_fast_furious,52.85066604614258,40ms,1,0,27ms,sps,,52ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_furious_inductor_cache_dir'},0.0,,8831ms,,0.22.0.dev20250201+cu124,28ms,46.42959976196289ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,861,8831ms,, +None,,1,903450624,,,,,,,395.61680269241333s,2.7.0.dev20250201+cu124,0.0img/s,1,,sps_ao_ppb_1_save_export_furious,405.58058881759644,,1,0,,sps,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_furious_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,861,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious +None,,1,1768025088,0,63ms,78ms,70ms,0.9996752961277962,31ms,40.04996109008789s,2.7.0.dev20250201+cu124,24.968813271768536img/s,1,41ms,sps_ao_ppb_1_load_export_furious_cold,44.494996547698975,33ms,1,1,54ms,sps,,58ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_furious_inductor_cache_dir'},0.0,,688ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,29ms,40.04996109008789ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,688ms,, +None,,1,1768025088,0,67ms,98ms,73ms,0.9996752961277962,54ms,41.31868815422058s,2.7.0.dev20250201+cu124,24.20212365570597img/s,1,24ms,sps_ao_ppb_1_load_export_furious,45.522459983825684,36ms,1,1,24ms,sps,,24ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_furious_inductor_cache_dir'},0.0,,769ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,31ms,41.31868815422058ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,769ms,, +None,,1,1794153472,0,28ms,33ms,30ms,0.9996936089992523,18ms,30.337790489196777s,2.7.0.dev20250201+cu124,32.96218952913192img/s,1,21ms,sps_ao_ppb_1_load_export_furious_gpu_preproc,35.1632604598999,22ms,1,1,22ms,sps,,22ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_furious_inductor_cache_dir'},0.0,,720ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,20ms,30.337790489196777ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1711,720ms,, +None,None,1,1768025088,0,59ms,82ms,69ms,0.9996752961277962,37ms,36.78891086578369s,2.7.0.dev20250201+cu124,27.182103967368906img/s,1,39ms,sps_ao_ppb_1_fast_export_furious_cold,40.70477890968323,31ms,1,1,53ms,sps,,35ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,,752ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,28ms,36.78891086578369ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,752ms,, +None,None,1,1768025088,0,62ms,74ms,69ms,0.9996752961277962,45ms,37.20629072189331s,2.7.0.dev20250201+cu124,26.877175353886315img/s,1,39ms,sps_ao_ppb_1_fast_export_furious,41.312560081481934,32ms,1,1,22ms,sps,,23ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,,678ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,29ms,37.20629072189331ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,678ms,, +None,None,1,1768025088,0,58ms,82ms,68ms,0.24502152660781712,19ms,44.12568783760071s,2.7.0.dev20250201+cu124,22.662536246015694img/s,1,62ms,sps_ao_ppb_1_fast_export_furious_recompiles,49.61470317840576,38ms,1,1,22ms,sps,,23ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,None,8124ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,28ms,44.12568783760071ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,8124ms,, +None,None,1,1794153472,0,26ms,29ms,27ms,0.9996936089992523,16ms,25.35749101638794s,2.7.0.dev20250201+cu124,39.436078252131644img/s,1,20ms,sps_ao_ppb_1_fast_export_furious_gpu_preproc,29.401476621627808,20ms,1,1,20ms,sps,,21ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,,662ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,19ms,25.35749101638794ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1711,662ms,, +None,None,1,1794153472,0,26ms,31ms,27ms,0.22546337781244644,17ms,26.919757604599s,2.7.0.dev20250201+cu124,37.14743701218019img/s,1,21ms,sps_ao_ppb_1_fast_export_furious_gpu_preproc_recompiles,32.35977077484131,22ms,1,1,20ms,sps,,21ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,None,2134ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,19ms,26.919757604599ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1711,2134ms,, +,,,1402492416,126,775ms,1593ms,1171ms,,150ms,331.5782699584961s,2.7.0.dev20250201+cu124,3.0158791772608344img/s,,289ms,baseline_mps,335.87450075149536,324ms,1,1,304ms,mps,,541ms,None,,,1991ms,,0.22.0.dev20250201+cu124,258ms,331.5782699584961ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1337,611ms,None, +,,,8411175424,0,227ms,311ms,239ms,0.999999164044857,105ms,143.97097539901733s,2.7.0.dev20250201+cu124,6.945844446969173img/s,,127ms,mps_ao,148.60355854034424,137ms,1,8,117ms,mps,,127ms,None,0.0,,634ms,,0.22.0.dev20250201+cu124,122ms,143.97097539901733ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,8021,634ms,, +,,,8411175424,0,234ms,309ms,259ms,0.999999164044857,221ms,164.95788407325745s,2.7.0.dev20250201+cu124,6.062153413388245img/s,1,234ms,mps_ao_ppb_None_basic,168.8498158454895,158ms,1,8,231ms,mps,,242ms,None,0.0,,644ms,,0.22.0.dev20250201+cu124,135ms,164.95788407325745ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,8021,644ms,, +,None,,8411176448,0,220ms,54779ms,243ms,,209ms,568.1692686080933s,2.7.0.dev20250201+cu124,1.7600388744181994img/s,1,1564ms,mps_ao_ppb_None_fast_cold,577.6140518188477,561ms,1,8,130ms,mps,,214ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_inductor_cache_dir'},,,332350ms,,0.22.0.dev20250201+cu124,115ms,568.1692686080933ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,8021,332350ms,, +,None,,8411176448,0,221ms,1345ms,240ms,0.9983834705352783,97ms,165.37928342819214s,2.7.0.dev20250201+cu124,6.0467065721336315img/s,1,580ms,mps_ao_ppb_None_fast,170.9393391609192,155ms,1,8,109ms,mps,,144ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_inductor_cache_dir'},0.0,,9522ms,,0.22.0.dev20250201+cu124,126ms,165.37928342819214ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,8021,9522ms,, +,,,1390578176,,,,,,,206.4340798854828s,2.7.0.dev20250201+cu124,0.0img/s,1,,mps_ao_ppb_None_save_export,217.42104578018188,,1,1,,mps,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1326,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast +,,,7556661248,0,218ms,322ms,236ms,0.998383426964283,104ms,138.59291863441467s,2.7.0.dev20250201+cu124,7.215375863739731img/s,1,116ms,mps_ao_ppb_None_load_export_cold,143.01005744934082,131ms,1,7,112ms,mps,,122ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_inductor_cache_dir'},0.0,,579ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,115ms,138.59291863441467ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7206,579ms,, +,,,7556661248,0,218ms,258ms,237ms,0.998383426964283,97ms,136.831298828125s,2.7.0.dev20250201+cu124,7.308269442476818img/s,1,116ms,mps_ao_ppb_None_load_export,141.67460775375366,129ms,1,7,111ms,mps,,120ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_inductor_cache_dir'},0.0,,589ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,114ms,136.831298828125ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7206,589ms,, +,,,7581827072,0,190ms,374ms,216ms,0.9984678273200989,170ms,149.05044078826904s,2.7.0.dev20250201+cu124,6.70913815961492img/s,1,187ms,mps_ao_ppb_None_load_export_gpu_preproc,153.32005190849304,142ms,1,7,181ms,mps,,143ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_inductor_cache_dir'},0.0,,596ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,135ms,149.05044078826904ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7230,596ms,, +,None,,7556661248,0,208ms,54466ms,226ms,0.9983833708167076,188ms,287.1738612651825s,2.7.0.dev20250201+cu124,3.482211074484173img/s,1,131ms,mps_ao_ppb_None_fast_export_cold,295.3504989147186,278ms,1,7,108ms,mps,,140ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_inductor_cache_dir'},0.0,,62539ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,109ms,287.1738612651825ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7206,62539ms,, +,None,,7556661248,0,218ms,1720ms,230ms,0.9983833900690079,195ms,141.05165219306946s,2.7.0.dev20250201+cu124,7.089601464796843img/s,1,230ms,mps_ao_ppb_None_fast_export,147.43897795677185,133ms,1,7,216ms,mps,,222ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_inductor_cache_dir'},0.0,,3561ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,111ms,141.05165219306946ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7206,3561ms,, +,None,,7581827072,0,185ms,1572ms,197ms,0.9984678581357003,94ms,148.53872227668762s,2.7.0.dev20250201+cu124,6.73225125861302img/s,1,107ms,mps_ao_ppb_None_fast_export_gpu_preproc,154.97156023979187,141ms,1,7,105ms,mps,,112ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_inductor_cache_dir'},0.0,,4246ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,127ms,148.53872227668762ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7230,4246ms,, +None,None,,4427842560,0,74ms,63302ms,84ms,0.9964296479523181,22ms,723.8993864059448s,2.7.0.dev20250201+cu124,1.3814074424967462img/s,1,1071ms,mps_ao_ppb_None_fast_furious_cold,733.4108500480652,716ms,1,4,29ms,mps,,37ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_furious_inductor_cache_dir'},0.0,,581345ms,,0.22.0.dev20250201+cu124,49ms,723.8993864059448ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,4222,581345ms,, +None,None,,4427842560,0,74ms,1300ms,85ms,0.9964293534457683,20ms,58.8767945766449s,2.7.0.dev20250201+cu124,16.9846202937936img/s,1,350ms,mps_ao_ppb_None_fast_furious,64.73449230194092,51ms,1,4,29ms,mps,,30ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_furious_inductor_cache_dir'},0.0,,8402ms,,0.22.0.dev20250201+cu124,34ms,58.8767945766449ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,4222,8402ms,, +None,,,903450624,,,,,,,315.72570967674255s,2.7.0.dev20250201+cu124,0.0img/s,1,,mps_ao_ppb_None_save_export_furious,324.74191069602966,,1,0,,mps,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_furious_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,861,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious +None,,,3998911488,0,82ms,301ms,90ms,0.9955771351754665,41ms,57.82986092567444s,2.7.0.dev20250201+cu124,17.292104528579888img/s,1,38ms,mps_ao_ppb_None_load_export_furious_cold,62.62674617767334,51ms,1,3,37ms,mps,,40ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_furious_inductor_cache_dir'},0.0,,754ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,46ms,57.82986092567444ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,754ms,, +None,,,3998911488,0,88ms,252ms,97ms,0.9955771351754665,32ms,65.55874681472778s,2.7.0.dev20250201+cu124,15.25349474458456img/s,1,80ms,mps_ao_ppb_None_load_export_furious,70.35485363006592,58ms,1,3,39ms,mps,,40ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_furious_inductor_cache_dir'},0.0,,875ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,53ms,65.55874681472778ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,875ms,, +None,,,4024077312,0,45ms,285ms,56ms,0.9959434471726417,29ms,41.67199182510376s,2.7.0.dev20250201+cu124,23.996933100701625img/s,1,35ms,mps_ao_ppb_None_load_export_furious_gpu_preproc,46.09472918510437,35ms,1,3,35ms,mps,,36ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_furious_inductor_cache_dir'},0.0,,653ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,32ms,41.67199182510376ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3837,653ms,, +None,None,,3998911488,0,68ms,51237ms,77ms,0.9966195167303086,20ms,211.8625111579895s,2.7.0.dev20250201+cu124,4.720042231795708img/s,1,27ms,mps_ao_ppb_None_fast_export_furious_cold,218.6763949394226,204ms,1,3,30ms,mps,,66ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,,79408ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,32ms,211.8625111579895ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,79408ms,, +None,None,,3998911488,0,70ms,1746ms,78ms,0.9966195802688599,59ms,51.70280361175537s,2.7.0.dev20250201+cu124,19.341310918246524img/s,1,43ms,mps_ao_ppb_None_fast_export_furious,57.28682208061218,44ms,1,3,34ms,mps,,70ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,,3842ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,35ms,51.70280361175537ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,3842ms,, +None,None,,3998911488,0,65ms,6664ms,75ms,0.9956195802688599,20ms,59.52086091041565s,2.7.0.dev20250201+cu124,16.8008322578716img/s,1,56ms,mps_ao_ppb_None_fast_export_furious_recompiles,64.74269723892212,52ms,1,3,27ms,mps,,29ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,None,11728ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,30ms,59.52086091041565ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,11728ms,, +None,None,,4024077312,0,37ms,1743ms,46ms,0.9960403459072114,19ms,37.689289808273315s,2.7.0.dev20250201+cu124,26.5327366232432img/s,1,26ms,mps_ao_ppb_None_fast_export_furious_gpu_preproc,42.8827166557312,31ms,1,3,27ms,mps,,30ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,,3914ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,23ms,37.689289808273315ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3837,3914ms,, +None,None,,4024077312,0,35ms,1672ms,43ms,0.9950685520768165,22ms,44.08118724822998s,2.7.0.dev20250201+cu124,22.685414400678457img/s,1,26ms,mps_ao_ppb_None_fast_export_furious_gpu_preproc_recompiles,50.419389486312866,36ms,1,3,26ms,mps,,31ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,None,9520ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,23ms,44.08118724822998ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3837,9520ms,, diff --git a/torchao/_models/sam2/automatic_mask_generator.py b/torchao/_models/sam2/automatic_mask_generator.py index 665a211035..6f4f1d3e7b 100644 --- a/torchao/_models/sam2/automatic_mask_generator.py +++ b/torchao/_models/sam2/automatic_mask_generator.py @@ -538,11 +538,11 @@ def _process_batch_fullgraph( ] image_embed_input = image_embed[-1].unsqueeze(0).clone() low_res_masks, iou_preds = self.predictor._predict_masks( - high_res_feats_input, - image_embed_input, - image_pe, - in_points[:, None, :], - in_labels[:, None], + [t.contiguous() for t in high_res_feats_input], + image_embed_input.contiguous(), + image_pe.contiguous(), + in_points[:, None, :].contiguous(), + in_labels[:, None].contiguous(), boxes=None, mask_input=None, multimask_output=self.multimask_output, diff --git a/torchao/_models/sam2/modeling/sam/prompt_encoder.py b/torchao/_models/sam2/modeling/sam/prompt_encoder.py index 6bb58d62ba..94b7fda8b2 100644 --- a/torchao/_models/sam2/modeling/sam/prompt_encoder.py +++ b/torchao/_models/sam2/modeling/sam/prompt_encoder.py @@ -186,6 +186,12 @@ def forward( torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W) """ + # if boxes is not None: + # raise ValueError("Currently do not support boxes. " + # "Please create an issue on pytorch/ao.") + # if masks is not None: + # raise ValueError("Currently do not support masks. " + # "Please create an issue on pytorch/ao.") bs = self._get_batch_size(points, boxes, masks) sparse_embeddings = torch.empty( (bs, 0, self.embed_dim), device=self._get_device() diff --git a/torchao/_models/sam2/sam2_image_predictor.py b/torchao/_models/sam2/sam2_image_predictor.py index 02d9aed547..a4aa1c668c 100644 --- a/torchao/_models/sam2/sam2_image_predictor.py +++ b/torchao/_models/sam2/sam2_image_predictor.py @@ -430,12 +430,15 @@ def _predict( for feat_level in high_res_feats ] image_embed_input = image_embed[img_idx].unsqueeze(0).clone() + assert boxes is None + assert mask_input is None + assert multimask_output is True low_res_masks, iou_predictions = self._predict_masks( - high_res_feats_input, - image_embed_input, - image_pe, - point_coords, - point_labels, + [t.contiguous() for t in high_res_feats_input], + image_embed_input.contiguous(), + image_pe.contiguous(), + point_coords.contiguous(), + point_labels.contiguous(), boxes=boxes, mask_input=mask_input, multimask_output=multimask_output, @@ -498,6 +501,10 @@ def _predict_masks( # ] high_res_features = high_res_feats_input with torch.autograd.profiler.record_function("self.model.sam_mask_decoder"): + # if not multimask_output: + # raise ValueError("Expected multimask_output.") + # if batched_mode: + # raise ValueError("Did not expected repeat_image.") low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( # image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0).clone(), # image_embeddings=image_embed[img_idx].unsqueeze(0).clone(), diff --git a/torchao/_models/sam2/utils/transforms.py b/torchao/_models/sam2/utils/transforms.py index 95970ba108..c616233050 100644 --- a/torchao/_models/sam2/utils/transforms.py +++ b/torchao/_models/sam2/utils/transforms.py @@ -27,11 +27,10 @@ def __init__( self.mean = [0.485, 0.456, 0.406] self.std = [0.229, 0.224, 0.225] self.to_tensor = ToTensor() - self.transforms = torch.jit.script( - nn.Sequential( - Resize((self.resolution, self.resolution)), - Normalize(self.mean, self.std), - ) + # self.transforms = torch.jit.script( + self.transforms = nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), ) def __call__(self, x): From 4df4d031adbadbbe99451241f82fe3ed9d446a8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= <115986737+alexsamardzic@users.noreply.github.com> Date: Wed, 5 Feb 2025 23:27:54 +0100 Subject: [PATCH 054/115] Moved CUTLASS pin to v3.7.0 (#1672) --- third_party/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cutlass b/third_party/cutlass index bf9da7b76c..b78588d163 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit bf9da7b76c766d7ee7d536afc77880a4ef1f1156 +Subproject commit b78588d1630aa6643bf021613717bafb705df4ef From bc1530b80a24db8c2bb9225709026560ebf90531 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 5 Feb 2025 15:55:29 -0800 Subject: [PATCH 055/115] Q dq layout (#1642) * add q-dq layout for ET * up * up * up * up * up * up * up --- .../workflows/torchao_experimental_test.yml | 3 +- torchao/experimental/q_dq_layout.py | 61 ++++++ ...est_int8_dynamic_activation_intx_weight.py | 186 ++++++++++++++++++ ...8_dynamic_activation_intx_weight_layout.py | 154 --------------- 4 files changed, 249 insertions(+), 155 deletions(-) create mode 100644 torchao/experimental/q_dq_layout.py create mode 100644 torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py delete mode 100644 torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index c1419bccc6..08f494c71d 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -35,8 +35,9 @@ jobs: conda activate venv pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104" pip install numpy + pip install pytest USE_CPP=1 pip install . - name: Run tests run: | conda activate venv - python torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py + pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py diff --git a/torchao/experimental/q_dq_layout.py b/torchao/experimental/q_dq_layout.py new file mode 100644 index 0000000000..b9337ae027 --- /dev/null +++ b/torchao/experimental/q_dq_layout.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.affine_quantized_tensor_ops import ( + register_aqt_quantized_linear_dispatch, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +import sys + +handler = logging.StreamHandler(sys.stdout) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +from torchao.dtypes.utils import PlainLayout + + +class QDQLayout(PlainLayout): + pass + + +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl + + +@register_layout(QDQLayout) +class _Impl(PlainAQTTensorImpl): + pass + + +def _linear_check(input_tensor, weight_tensor, bias): + layout = weight_tensor.tensor_impl.get_layout() + return isinstance(layout, QDQLayout) + + +def _linear_impl(input_tensor, weight_tensor, bias): + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +register_aqt_quantized_linear_dispatch( + _linear_check, + _linear_impl, +) diff --git a/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py new file mode 100644 index 0000000000..63a8892425 --- /dev/null +++ b/torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import itertools +import tempfile +import unittest + +import torch +from torch.testing import FileCheck + +from torchao.dtypes import PlainLayout +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) +from torchao.experimental.q_dq_layout import QDQLayout +from torchao.experimental.quant_api import ( + int8_dynamic_activation_intx_weight, +) +from torchao.quantization.granularity import ( + PerGroup, + PerRow, +) +from torchao.quantization.quant_api import quantize_ +from torchao.utils import unwrap_tensor_subclass + + +class TestInt8DynamicActivationIntxWeight(unittest.TestCase): + def test_accuracy(self): + """ + Checks the accuracy of different layouts by comparing the results to PlainLayout() + """ + m = 1 + n = 1071 + k = 4096 + activations = torch.randn(m, k) + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) + + reference_layout = PlainLayout() + test_layouts = [ + PackedLinearInt8DynamicActivationIntxWeightLayout(), + QDQLayout(), + ] + test_weight_dtypes = [ + torch.int1, + torch.int2, + torch.int3, + torch.int4, + torch.int5, + torch.int6, + torch.int7, + torch.int8, + ] + test_has_weight_zeros = [True, False] + test_granularities = [PerGroup(128), PerRow()] + for layout, weight_dtype, has_weight_zeros, granularity in itertools.product( + test_layouts, test_weight_dtypes, test_has_weight_zeros, test_granularities + ): + quantized_model = copy.deepcopy(model) + quantize_( + quantized_model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=layout, + ), + ) + + quantized_model_reference = copy.deepcopy(model) + quantize_( + quantized_model_reference, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=reference_layout, + ), + ) + + with torch.no_grad(): + result = quantized_model(activations) + expected_result = quantized_model_reference(activations) + self.assertTrue(torch.allclose(result, expected_result, atol=1e-6)) + + def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout( + self, + ): + """ + Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with + torch.export.export, torch.compile, and AOTI. + """ + granularity = PerRow() + m = 3 + k0 = 512 + k1 = 256 + k2 = 128 + k3 = 1024 + weight_dtype = torch.int4 + has_weight_zeros = True + layers = [ + torch.nn.Linear(k0, k1, bias=False), + torch.nn.Linear(k1, k2, bias=False), + torch.nn.Linear(k2, k3, bias=False), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn(2, 1, m, k0, dtype=torch.float32) + + quantize_( + model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), + ), + ) + eager_results = model(activations) + + unwrapped_model = copy.deepcopy(model) + unwrap_tensor_subclass(model) + + # Export + exported = torch.export.export(model, (activations,), strict=True) + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(eager_results, exported_results)) + + # Compile + compiled = torch.compile(unwrapped_model) + with torch.no_grad(): + compiled_results = compiled(activations) + self.assertTrue(torch.allclose(eager_results, compiled_results)) + + # AOTI + with tempfile.TemporaryDirectory() as tmpdirname: + package_path = f"{tmpdirname}/model.pt2" + torch._inductor.aoti_compile_and_package( + exported, package_path=package_path + ) + fn = torch._inductor.aoti_load_package(package_path) + aoti_results = fn(activations) + self.assertTrue(torch.allclose(eager_results, aoti_results)) + + def test_export_QDQLayout(self): + """ + Checks that models quantized with TestQDQLayout() export as expected + """ + granularity = PerGroup(64) + weight_dtype = torch.int4 + has_weight_zeros = False + layers = [ + torch.nn.Linear(512, 256, bias=False), + ] + model = torch.nn.Sequential(*layers) + activations = torch.randn(1, 512, dtype=torch.float32) + + quantize_( + model, + int8_dynamic_activation_intx_weight( + weight_dtype=weight_dtype, + granularity=granularity, + has_weight_zeros=has_weight_zeros, + layout=QDQLayout(), + ), + ) + eager_results = model(activations) + + unwrap_tensor_subclass(model) + exported = torch.export.export(model, (activations,), strict=True) + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(eager_results, exported_results)) + + expected_lines = [ + "torch.ops.quant.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int32, -128, 127, None, torch.float32, torch.int32)", + "torch.ops.quant.quantize_affine.default(input_1, [1, 512], getitem, getitem_1, torch.int32, -128, 127)", + "torch.ops.quant.dequantize_affine.default(quantize_affine, [1, 512], getitem, getitem_1, torch.int32, -128, 127)", + "torch.ops.quant.dequantize_affine.default(p_fn_0_parametrizations_weight_original0, [1, 64], p_fn_0_parametrizations_weight_original1, None, torch.int32, -8, 7, 'NONE')", + "torch.ops.aten.linear.default(dequantize_affine, dequantize_affine_1)", + ] + for line in expected_lines: + FileCheck().check_count(line, 1, exactly=True).run( + exported.graph_module.code + ) diff --git a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py deleted file mode 100644 index 284ef4b2a8..0000000000 --- a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import copy -import tempfile -import unittest - -import torch - -from torchao.dtypes import PlainLayout -from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, -) -from torchao.experimental.quant_api import ( - int8_dynamic_activation_intx_weight, -) -from torchao.quantization.granularity import ( - PerGroup, - PerRow, -) -from torchao.quantization.quant_api import quantize_ -from torchao.utils import unwrap_tensor_subclass - - -class TestPackedLinearInt8DynamicActivationIntxWeightLayout(unittest.TestCase): - def test_accuracy(self): - """ - Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing - its results to the results of a reference model that uses PlainLayout() - """ - granularity = PerGroup(128) - m = 1 - n = 1071 - k = 4096 - activations = torch.randn(m, k) - model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) - - for weight_dtype in [ - torch.int1, - torch.int2, - torch.int3, - torch.int4, - torch.int5, - torch.int6, - torch.int7, - torch.int8, - ]: - for has_weight_zeros in [True, False]: - print( - f"Testing weight_dtype={weight_dtype}, has_weight_zeros={has_weight_zeros}" - ) - quantized_model = copy.deepcopy(model) - quantize_( - quantized_model, - int8_dynamic_activation_intx_weight( - weight_dtype=weight_dtype, - granularity=granularity, - has_weight_zeros=has_weight_zeros, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # default - ), - ) - - quantized_model_reference = copy.deepcopy(model) - quantize_( - quantized_model_reference, - int8_dynamic_activation_intx_weight( - weight_dtype=weight_dtype, - granularity=granularity, - has_weight_zeros=has_weight_zeros, - layout=PlainLayout(), - ), - ) - - with torch.no_grad(): - result = quantized_model(activations) - expected_result = quantized_model_reference(activations) - - num_mismatch_at_low_tol = 0 - num_total = result.reshape(-1).shape[0] - for i in range(num_total): - actual_val = result.reshape(-1)[i] - expected_val = expected_result.reshape(-1)[i] - self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6)) - if not torch.allclose(actual_val, expected_val): - num_mismatch_at_low_tol += 1 - - # Assert at most 5% of entries are not close at a low tolerance - self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) - - def test_export_compile_aoti(self): - """ - Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with - torch.export.export, torch.compile, and AOTI. - """ - granularity = PerRow() - m = 3 - k0 = 512 - k1 = 256 - k2 = 128 - k3 = 1024 - weight_dtype = torch.int4 - has_weight_zeros = True - layers = [ - torch.nn.Linear(k0, k1, bias=False), - torch.nn.Linear(k1, k2, bias=False), - torch.nn.Linear(k2, k3, bias=False), - ] - model = torch.nn.Sequential(*layers) - activations = torch.randn(2, 1, m, k0, dtype=torch.float32) - - print("Quantizing model") - quantize_( - model, - int8_dynamic_activation_intx_weight( - weight_dtype=weight_dtype, - granularity=granularity, - has_weight_zeros=has_weight_zeros, - layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), - ), - ) - eager_results = model(activations) - - unwrapped_model = copy.deepcopy(model) - unwrap_tensor_subclass(model) - - print("Exporting quantized model") - exported = torch.export.export(model, (activations,), strict=True) - exported_results = exported.module()(activations) - self.assertTrue(torch.allclose(eager_results, exported_results)) - - print("Compiling quantized model") - compiled = torch.compile(unwrapped_model) - with torch.no_grad(): - compiled_results = compiled(activations) - self.assertTrue(torch.allclose(eager_results, compiled_results)) - - with tempfile.TemporaryDirectory() as tmpdirname: - package_path = f"{tmpdirname}/model.pt2" - print("Exporting quantized model with AOTI") - torch._inductor.aoti_compile_and_package( - exported, package_path=package_path - ) - - print("Running quantized model in AOTI") - fn = torch._inductor.aoti_load_package(package_path) - aoti_results = fn(activations) - self.assertTrue(torch.allclose(eager_results, aoti_results)) - - -if __name__ == "__main__": - unittest.main() From c6611be254be9563d045f515d94c20c8c54be8ec Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Wed, 5 Feb 2025 16:01:48 -0800 Subject: [PATCH 056/115] Remove duplicate definitions of fill_defaults (#1674) --- torchao/dtypes/uintx/uint4_layout.py | 27 ++------------------------- torchao/prototype/dtypes/uint2.py | 11 ++--------- 2 files changed, 4 insertions(+), 34 deletions(-) diff --git a/torchao/dtypes/uintx/uint4_layout.py b/torchao/dtypes/uintx/uint4_layout.py index 204aefcf3c..0b6512640e 100644 --- a/torchao/dtypes/uintx/uint4_layout.py +++ b/torchao/dtypes/uintx/uint4_layout.py @@ -3,6 +3,8 @@ import torch.utils._pytree as pytree from torch.library import Library, impl +from torchao.utils import fill_defaults + def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" @@ -13,31 +15,6 @@ def up_size(size): return (*size[:-1], size[-1] * 2) -def fill_defaults(args, n, defaults_tail): - """ - __torch_dispatch__ doesn't guarantee the number of arguments you are - passed (e.g., defaulted arguments are not passed); but usually it is - convenient to pad out the arguments list with defaults. This function - helps you do that. - Args: - args: the list of positional arguments passed to __torch_dispatch__ - n: the number of arguments you are expecting to get - defaults_tail: default values for the arguments, starting from the - end of the list - Example: - >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) - [1, 2, 3, 4, 5] - >>> fill_defaults([1, 2, 3], 5, [None, None, None]) - [1, 2, 3, None, None]] - """ - if n - len(defaults_tail) > len(args): - raise RuntimeError("not enough defaults to fill arguments") - r = list(args) - for i in range(len(args), n): - r.append(defaults_tail[i - n + len(defaults_tail)]) - return r - - # from # https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233 diff --git a/torchao/prototype/dtypes/uint2.py b/torchao/prototype/dtypes/uint2.py index 9c14d8ae72..d54e541751 100644 --- a/torchao/prototype/dtypes/uint2.py +++ b/torchao/prototype/dtypes/uint2.py @@ -4,16 +4,9 @@ import torch import torch._prims_common as utils -UINT2_OPS_TABLE: Dict[Any, Any] = {} - +from torchao.utils import fill_defaults -def fill_defaults(args, n, defaults_tail): - if n - len(defaults_tail) > len(args): - raise RuntimeError("not enough defaults to fill arguments") - r = list(args) - for i in range(len(args), n): - r.append(defaults_tail[i - n + len(defaults_tail)]) - return r +UINT2_OPS_TABLE: Dict[Any, Any] = {} def implements(aten_ops): From 867a91f930d16f1a79eda3c2d505851e3817b786 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Wed, 5 Feb 2025 23:32:29 -0500 Subject: [PATCH 057/115] update notify in build_wheels_linux.yml (#1676) remove debug code --- .github/workflows/build_wheels_linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_wheels_linux.yml b/.github/workflows/build_wheels_linux.yml index 8b966059f3..fd16bf37a8 100644 --- a/.github/workflows/build_wheels_linux.yml +++ b/.github/workflows/build_wheels_linux.yml @@ -70,7 +70,7 @@ jobs: password: ${{ secrets.TORCHAO_NOTIFY_PASSWORD }} from: torchao.notify@gmail.com to: ${{ secrets.TORCHAO_NOTIFY_RECIPIENT }} - subject: breakbutterflyScheduled Build Failure for TorchAO + subject: Scheduled Build Failure for TorchAO body: | Build Failure Notification for TorchAO From 1d75c8fb46c58ac1f6ed641f93ba6a0ca78b33e8 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Thu, 6 Feb 2025 18:20:08 +0000 Subject: [PATCH 058/115] Support mixed MX element dtype in `mx_mm` function and `MXLinear`. (#1667) * Support mixed MX element dtype in `mx_mm` function. Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients. * Support (input, weight, gradient) element dtype tuple in MXLinear layer factory method. Passing a tuple of 3 element dtypes avoids introducing a breaking change in the current interface of `MXLinear` and `swap_linear_with_mx_linear`. Some additional unit test coverage has been added on MXLinear. * Using default `elem_dtype` argument and optional weight/grad overrides. --- test/prototype/mx_formats/test_mx_linear.py | 32 +++++++-- torchao/prototype/mx_formats/README.md | 9 ++- torchao/prototype/mx_formats/mx_linear.py | 73 ++++++++++++++++----- 3 files changed, 88 insertions(+), 26 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 35afeb7959..17a76a750d 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy +import itertools import pytest import torch @@ -41,13 +42,16 @@ def run_around_tests(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +@pytest.mark.parametrize( + "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3) +) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)]) def test_linear_eager(elem_dtype, bias, input_shape): """ Smoke test for training linear module with mx weight """ + # elem_dtype is a tuple of (input, weight, gradient) dtypes. grad_shape = list(input_shape) grad_shape[-1] = 6 @@ -56,7 +60,7 @@ def test_linear_eager(elem_dtype, bias, input_shape): ) m_mx = copy.deepcopy(m) block_size = 2 - swap_linear_with_mx_linear(m_mx, elem_dtype, block_size) + swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size) x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() x = copy.deepcopy(x_ref) @@ -72,7 +76,7 @@ def test_linear_eager(elem_dtype, bias, input_shape): w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad) x_g_sqnr = compute_error(x_ref.grad, x.grad) - if elem_dtype is torch.float8_e4m3fn: + if elem_dtype == (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn): assert y_sqnr >= 18.0 assert w_g_sqnr >= 18.0 assert x_g_sqnr >= 12.0 @@ -94,7 +98,7 @@ def test_activation_checkpointing(): nn.Linear(6, 6, bias=True, device="cuda"), ) block_size = 2 - swap_linear_with_mx_linear(m, elem_dtype, block_size) + swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size) x = torch.randn(*input_shape, device="cuda").requires_grad_() g = torch.randn(*grad_shape, device="cuda") @@ -130,7 +134,7 @@ def test_linear_compile(elem_dtype, bias, use_autocast): nn.Linear(K, N, bias=bias, device="cuda"), ) block_size = 2 - swap_linear_with_mx_linear(m_mx, elem_dtype, block_size) + swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size) m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") @@ -219,6 +223,20 @@ def test_inference_compile_simple(elem_dtype): assert sqnr >= 13.5 +def test_mx_linear_input_weight_gradient_dtypes(): + m = nn.Sequential(nn.Linear(32, 32)) + swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32) + assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0] + assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1] + assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2] + + m = nn.Sequential(nn.Linear(32, 32)) + swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32) + assert m[0].in_elem_dtype == torch.float8_e4m3fn + assert m[0].w_elem_dtype == torch.float8_e4m3fn + assert m[0].grad_elem_dtype == torch.float8_e4m3fn + + def test_filter_fn(): m1 = nn.Sequential( nn.Linear(32, 32), @@ -227,7 +245,9 @@ def test_filter_fn(): m2 = copy.deepcopy(m1) filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731 - swap_linear_with_mx_linear(m1, torch.float8_e4m3fn, 32, filter_fn) + swap_linear_with_mx_linear( + m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn + ) assert type(m1[0]) == MXLinear assert type(m1[1]) == torch.nn.Linear diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index b750c26af2..32f45e3755 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -2,8 +2,8 @@ This is a POC of training and inference with tensors in the MX format from the OCP spec (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch. -Note that the current version of the code is written for readability and -numerical correctness and not yet for optimal performance. We welcome +Note that the current version of the code is written for readability and +numerical correctness and not yet for optimal performance. We welcome contributions on performance improvements. Note that there are no BC guarantees at the moment and we plan to evolve @@ -44,8 +44,7 @@ from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() elem_dtype = torch.float8_e4m3fn -block_size = 32 -swap_linear_with_mx_linear(m, elem_dtype, block_size) +swap_linear_with_mx_linear(m, elem_dtype, block_size=32) # training loop (not shown) ``` @@ -93,7 +92,7 @@ python torchao/prototype/mx_formats/benchmarks/bench_qdq.py ## floating point format convenience functions -We have a convenience script which summarizes the various properties of +We have a convenience script which summarizes the various properties of floating point formats: ```bash diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index b69441e018..d7aa744334 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function): # 1. input @ weight_t = output (forward pass) # 2. grad_output @ weight = grad_input (backward pass) # 3. input_t @ grad_output = grad_weight (backward pass) + # + # input, weight and grad_output can have each their own MX element dtype. @staticmethod def forward( ctx, input_hp: torch.Tensor, weight_hp: torch.Tensor, - elem_dtype: Any, + in_elem_dtype: Any, + w_elem_dtype: Any, + grad_elem_dtype: Any, block_size: int, ): ctx.save_for_backward(input_hp, weight_hp) - ctx.elem_dtype = elem_dtype + ctx.in_elem_dtype = in_elem_dtype + ctx.w_elem_dtype = w_elem_dtype + ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size # input @ weight_t = output input_orig_shape = input_hp.shape input_hp_r = input_hp.reshape(-1, input_orig_shape[-1]) - input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size) - weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size) + input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size) + weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size) output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t()) output = output.reshape(*input_orig_shape[:-1], output.shape[-1]) @@ -51,7 +57,9 @@ def forward( def backward(ctx, grad_output_hp: torch.Tensor): input_hp, weight_hp = ctx.saved_tensors weight_hp_t_c = weight_hp.t().contiguous() - elem_dtype = ctx.elem_dtype + in_elem_dtype = ctx.in_elem_dtype + w_elem_dtype = ctx.w_elem_dtype + grad_elem_dtype = ctx.grad_elem_dtype block_size = ctx.block_size grad_output_orig_shape = grad_output_hp.shape @@ -61,8 +69,10 @@ def backward(ctx, grad_output_hp: torch.Tensor): input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1]) # grad_output @ weight = grad_input - grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size) - weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size) + grad_output_mx_dim0 = MXTensor.to_mx( + grad_output_hp_r, grad_elem_dtype, block_size + ) + weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) grad_input = grad_input.reshape( *grad_output_orig_shape[:-1], grad_input.shape[-1] @@ -70,15 +80,15 @@ def backward(ctx, grad_output_hp: torch.Tensor): # input_t @ grad_output = grad_weight grad_output_mx_dim1 = MXTensor.to_mx( - grad_output_hp_r.t().contiguous(), elem_dtype, block_size + grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size ) input_t_mx_dim0_tmp = MXTensor.to_mx( - input_hp_r.t().contiguous(), elem_dtype, block_size + input_hp_r.t().contiguous(), in_elem_dtype, block_size ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) - return grad_input, grad_weight, None, None + return grad_input, grad_weight, None, None, None, None class MXLinear(torch.nn.Linear): @@ -87,13 +97,25 @@ class MXLinear(torch.nn.Linear): matmul is emulated since there is no hardware support yet. Activations, weights and grads are casted to MX and back to high precision for each matmul. + + Input, weight and grad_output can have each their own MX element dtype. """ @classmethod @torch.no_grad() - def from_float(cls, mod, elem_dtype, block_size): + def from_float( + cls, + mod, + elem_dtype, + elem_dtype_weight_override=None, + elem_dtype_grad_output_override=None, + *, + block_size=32, + ): mod.__class__ = MXLinear - mod.elem_dtype = elem_dtype + mod.in_elem_dtype = elem_dtype + mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype + mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype mod.block_size = block_size return mod @@ -106,7 +128,14 @@ def forward(self, x): else: w = self.weight - y = mx_mm.apply(x, w, self.elem_dtype, self.block_size) + y = mx_mm.apply( + x, + w, + self.in_elem_dtype, + self.w_elem_dtype, + self.grad_elem_dtype, + self.block_size, + ) if self.bias is not None: y = y + self.bias return y @@ -172,7 +201,15 @@ def _is_linear(mod, fqn): return isinstance(mod, torch.nn.Linear) -def swap_linear_with_mx_linear(model, elem_dtype, block_size, filter_fn=None): +def swap_linear_with_mx_linear( + model, + elem_dtype, + elem_dtype_weight_override=None, + elem_dtype_grad_output_override=None, + *, + block_size=32, + filter_fn=None, +): if filter_fn is None: combined_filter_fn = _is_linear else: @@ -183,7 +220,13 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXLinear.from_float(mod, elem_dtype, block_size), + lambda mod: MXLinear.from_float( + mod, + elem_dtype, + elem_dtype_weight_override, + elem_dtype_grad_output_override, + block_size=block_size, + ), combined_filter_fn, ) From 753ba98706cd02ab4e5b6cba76815ed594daeb67 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 6 Feb 2025 15:08:46 -0800 Subject: [PATCH 059/115] Test fix (#1678) --- .github/workflows/build_wheels_aarch64_linux.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_wheels_aarch64_linux.yml b/.github/workflows/build_wheels_aarch64_linux.yml index 56ea528a69..0f64aa53bf 100644 --- a/.github/workflows/build_wheels_aarch64_linux.yml +++ b/.github/workflows/build_wheels_aarch64_linux.yml @@ -29,7 +29,8 @@ jobs: test-infra-repository: pytorch/test-infra test-infra-ref: main with-cuda: disable - + # please note: excluding 3.13t for aarch64 builds for now + python-versions: '["3.9", "3.10", "3.11", "3.12", "3.13"]' build: needs: generate-matrix permissions: From d1e6c03b6d28f6dab3d9f55ff828f95a37e1acc8 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 6 Feb 2025 15:41:29 -0800 Subject: [PATCH 060/115] CI fix for linux wheels (#1679) --- .github/workflows/build_wheels_linux.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_wheels_linux.yml b/.github/workflows/build_wheels_linux.yml index fd16bf37a8..96801257da 100644 --- a/.github/workflows/build_wheels_linux.yml +++ b/.github/workflows/build_wheels_linux.yml @@ -30,6 +30,8 @@ jobs: with-cuda: enable with-rocm: enable with-xpu: enable + # please note: excluding 3.13t for aarch64 builds for now + python-versions: '["3.9", "3.10", "3.11", "3.12", "3.13"]' build: needs: generate-matrix @@ -89,5 +91,5 @@ jobs: Error Information: ${{ needs.generate-matrix.result == 'failure' && 'Matrix generation failed' || '' }} ${{ needs.build.result == 'failure' && 'Build job failed' || '' }} - + This is an automated notification. Please check the GitHub Actions page for more details about the failure. From cc6244c864416926877fc469f6d46db900a90f61 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 6 Feb 2025 19:05:06 -0800 Subject: [PATCH 061/115] Add boiler plate code to Tensor subclass (#1663) --- torchao/utils.py | 57 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/torchao/utils.py b/torchao/utils.py index f67463f9f7..13b59c2e81 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -512,6 +512,27 @@ def _get_tensor_impl_constructor( return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] +def _get_to_kwargs(self, *args, **kwargs): + # `torch._C._nn._parse_to` can't handle `layout` argument + for arg in args: + if isinstance(arg, torch.layout): + args.remove(arg) + if "layout" in kwargs: + kwargs.pop("layout") + # ignoring `non_blocking` and `memory_format` args since these are not + # very useful for most of the tensor subclasses + # if in the future there are use cases that need these, we'd recommend + # to override `_get_to_kwargs` and return these args + device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + kwargs = { + "device": device, + "dtype": dtype, + } + return kwargs + + class TorchAOBaseTensor(torch.Tensor): """A util tensor subclass that provides commonly used functions new tensor subclass can inherit it to get all the utility functions @@ -552,26 +573,24 @@ class PlainAQTTensorImpl(...): __torch_function__ = classmethod(_dispatch__torch_function__) register_layout = classmethod(_register_layout) get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor) + _get_to_kwargs = _get_to_kwargs + + def __tensor_flatten__(self): + raise NotImplementedError("Subclasses must implement __tensor_flatten__") + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + raise NotImplementedError("Subclasses must implement __tensor_unflatten__") + + def __repr__(self): + raise NotImplementedError("Subclasses must implement __repr__") - def _get_to_kwargs(self, *args, **kwargs): - # `torch._C._nn._parse_to` can't handle `layout` argument - for arg in args: - if isinstance(arg, torch.layout): - args.remove(arg) - if "layout" in kwargs: - kwargs.pop("layout") - # ignoring `non_blocking` and `memory_format` args since these are not - # very useful for most of the tensor subclasses - # if in the future there are use cases that need these, we'd recommend - # to override `_get_to_kwargs` and return these args - device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - kwargs = { - "device": device, - "dtype": dtype, - } - return kwargs + def get_layout(self): + if not hasattr(self, "_layout"): + return None + return self._layout def fill_defaults(args, n, defaults_tail): From e7aa4cad812b39e71f69c6d1b3ec8cb61fe9b37f Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 7 Feb 2025 08:44:09 -0800 Subject: [PATCH 062/115] add a deprecation warning for float8 delayed and static scaling (#1681) Update [ghstack-poisoned] --- torchao/float8/README.md | 2 ++ torchao/float8/config.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/torchao/float8/README.md b/torchao/float8/README.md index 8487096e6c..ddc717f953 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -65,6 +65,8 @@ for _ in range(10): ## float8 linear with delayed scaling +:warning: We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details. + This is theoretically the most performant recipe as it minimizes memory reads. ```python diff --git a/torchao/float8/config.py b/torchao/float8/config.py index c7f32cd3fa..fb306e0fb7 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -304,6 +304,16 @@ def __post_init__(self): "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd." ) + # Future deprecation warning for delayed scaling + if ( + self.cast_config_input.scaling_type != ScalingType.DYNAMIC + or self.cast_config_weight.scaling_type != ScalingType.DYNAMIC + or self.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC + ): + logger.warning( + "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details." + ) + # Pre-made recipes for common configurations # TODO(future PR): go through a round of design on this, and eventually expose From c8eb8d31dd8c4ef744e49fa215db439d7d5884f7 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Fri, 7 Feb 2025 11:47:25 -0800 Subject: [PATCH 063/115] Lint fixes for fbcode (#1682) --- ...r_int8_dynamic_activation_intx_weight_layout_target_aten.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py index 2a08d0e548..9cf85893ea 100644 --- a/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py +++ b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout_target_aten.py @@ -43,8 +43,7 @@ def test_accuracy(self): for has_weight_zeros in [True]: for granularity in granularities: print( - f"Testing weight_dtype={weight_dtype}, has_weight_zeros={ - has_weight_zeros}, granularity={granularity}" + f"Testing weight_dtype={weight_dtype}, has_weight_zeros={has_weight_zeros}, granularity={granularity}" ) quantized_model = copy.deepcopy(model) quantize_( From 4d1c7741842a1dfbd479b3481fcdc93c64db703e Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Sun, 9 Feb 2025 14:28:05 -0800 Subject: [PATCH 064/115] SAM2: Modal experiments QoL improvements (#1683) --- examples/sam2_amg_server/cli_on_modal.py | 271 +++++++++--------- examples/sam2_amg_server/compare_rle_lists.py | 24 +- examples/sam2_amg_server/modal_experiments.sh | 63 ++-- 3 files changed, 199 insertions(+), 159 deletions(-) diff --git a/examples/sam2_amg_server/cli_on_modal.py b/examples/sam2_amg_server/cli_on_modal.py index 1c384d3288..5fe56eeb1a 100644 --- a/examples/sam2_amg_server/cli_on_modal.py +++ b/examples/sam2_amg_server/cli_on_modal.py @@ -1,12 +1,11 @@ +import asyncio import json -import time from pathlib import Path import fire import modal TARGET = "/root/" -DOWNLOAD_URL_BASE = "https://raw.githubusercontent.com/pytorch/ao/refs/heads" SAM2_GIT_SHA = "c2ec8e14a185632b0a5d8b161928ceb50197eddc" image = ( @@ -25,11 +24,8 @@ .apt_install("git") .apt_install("libopencv-dev") .apt_install("python3-opencv") - .run_commands(["git clone https://github.com/pytorch/ao.git /tmp/ao_src_0"]) - .run_commands( - ["cd /tmp/ao_src_0; git checkout 1be4307db06d2d7e716d599c1091a388220a61e4"] - ) - .run_commands(["cd /tmp/ao_src_0; python setup.py develop"]) + .run_commands([f"git clone https://github.com/pytorch/ao.git {TARGET}ao_src_0"]) + .run_commands([f"cd {TARGET}ao_src_0; python setup.py develop"]) .pip_install( "gitpython", ) @@ -42,9 +38,9 @@ .pip_install_from_requirements( "requirements.txt", ) - # .pip_install( - # f"git+https://github.com/facebookresearch/sam2.git@{SAM2_GIT_SHA}", - # ) + .pip_install( + f"git+https://github.com/facebookresearch/sam2.git@{SAM2_GIT_SHA}", + ) ) app = modal.App("torchao-sam-2-cli", image=image) @@ -62,7 +58,7 @@ @app.cls( gpu="H100", container_idle_timeout=20 * 60, - concurrency_limit=1, + concurrency_limit=10, allow_concurrent_inputs=1, timeout=20 * 60, volumes={ @@ -73,76 +69,38 @@ }, ) class Model: - def calculate_file_hash(self, file_path, hash_algorithm="sha256"): - import hashlib - - """Calculate the hash of a file.""" - hash_func = hashlib.new(hash_algorithm) - with open(file_path, "rb") as f: - for chunk in iter(lambda: f.read(4096), b""): - hash_func.update(chunk) - return hash_func.hexdigest() - - def download_file(self, url, filename): - import subprocess - - command = f"wget -O {filename} {url}" - subprocess.run(command, shell=True, check=True) - - def download_and_verify_file( - self, url, filename, hash_value, hash_algorithm="sha256" - ): - if Path(filename).exists(): - h = self.calculate_file_hash(filename, hash_algorithm) - if hash_value == h: - return - # Here either the file doesn't exist or the file - # has the wrong hash, so we try to download it again. - self.download_file(url, filename) - h = self.calculate_file_hash(filename, hash_algorithm) - if h != hash_value: - raise ValueError( - f"Url {url} doesn't contain file with " - f"{hash_algorithm} hash of value " - f"{hash_value}" - ) + task_type: str = modal.parameter(default="amg") + baseline: int = modal.parameter(default=0) @modal.build() @modal.enter() def build(self): import os - from torchao._models.sam2.automatic_mask_generator import ( - SAM2AutomaticMaskGenerator, - ) - from torchao._models.sam2.build_sam import build_sam2 - # Baseline - # from sam2.build_sam import build_sam2 - # from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator - - download_url_branch = "main" - download_url = f"{DOWNLOAD_URL_BASE}/{download_url_branch}/" - download_url = download_url + "examples/sam2_amg_server" - - file_hashes = { - "cli.py": "8bce88807fe360babd7694f7ee009d7ea6cdc150a4553c41409589ec557b4c4b", - "server.py": "2d79458fabab391ef45cdc3ee9a1b62fea9e7e3b16e0782f522064d6c3c81a17", - "compile_export_utils.py": "552c422a5c267e57d9800e5080f2067f25b4e6a3b871b2063a2840033f4988d0", - "annotate_with_rle.py": "87ecb734c4b2bcdd469e0e373f73727316e844e98f263c6a713c1ce4d6e1f0f6", - "generate_data.py": "5ff754a0845ba0d706226013be2ebf46268a6d46c7bc825ff7dbab0de048a0a7", - } - - for f in file_hashes: - self.download_and_verify_file( - f"{download_url}/{f}", TARGET + f"data/{f}", file_hashes[f] + import numpy as np + import torch + + if self.baseline: + from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator + from sam2.build_sam import build_sam2 + else: + from torchao._models.sam2.automatic_mask_generator import ( + SAM2AutomaticMaskGenerator, ) + from torchao._models.sam2.build_sam import build_sam2 - os.chdir(Path(TARGET + "data")) + os.chdir(f"{TARGET}ao_src_0/examples/sam2_amg_server") import sys sys.path.append(".") - from server import model_type_to_paths + from server import ( + file_bytes_to_image_tensor, + masks_to_rle_dict, + model_type_to_paths, + profiler_runner, + show_anns, + ) device = "cuda" checkpoint_path = Path(TARGET) / Path("checkpoints") @@ -150,46 +108,42 @@ def build(self): sam2 = build_sam2( model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False ) + points_per_batch = None + if self.task_type == "amg": + points_per_batch = 64 if self.baseline else 1024 + if self.task_type == "sps": + points_per_batch = 1 mask_generator = SAM2AutomaticMaskGenerator( - sam2, points_per_batch=1024, output_mode="uncompressed_rle" + sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle" ) - # from compile_export_utils import load_exported_model - # mask_generator = load_exported_model(mask_generator, - # Path(TARGET) / Path("exported_models"), - # # Currently task_type has no effect, - # # because we can only export the image - # # encoder, but this might change soon. - # "amg", # task_type - # furious=True, - # batch_size=1, - # points_per_batch=1024) - self.mask_generator = mask_generator - import os - import sys - - import numpy as np - import torch + from compile_export_utils import load_exported_model - os.chdir(Path(TARGET + "data")) - sys.path.append(".") - from server import ( - file_bytes_to_image_tensor, - masks_to_rle_dict, - profiler_runner, - show_anns, + export_model_path = Path(TARGET) / Path("exported_models") + export_model_path = ( + export_model_path / Path("sam2") / Path(f"sam2_{self.task_type}") ) + if not self.baseline: + load_exported_model( + mask_generator, + export_model_path, + self.task_type, + furious=True, + batch_size=1, + points_per_batch=points_per_batch, + ) + self.mask_generator = mask_generator from torchvision import io as tio from torchvision.transforms.v2 import functional as tio_F - from torchao._models.sam2.utils.amg import ( - area_from_rle, - mask_to_rle_pytorch_2, - rle_to_mask, - ) - - # Baselien - # from sam2.utils.amg import rle_to_mask - # from sam2.utils.amg import mask_to_rle_pytorch as mask_to_rle_pytorch_2 + if self.baseline: + from sam2.utils.amg import mask_to_rle_pytorch as mask_to_rle_pytorch_2 + from sam2.utils.amg import rle_to_mask + else: + from torchao._models.sam2.utils.amg import ( + mask_to_rle_pytorch_2, + rle_to_mask, + ) + from torchao._models.sam2.utils.amg import area_from_rle self.np = np self.tio = tio @@ -207,12 +161,26 @@ def build(self): self._get_center_point = _get_center_point - from generate_data import gen_masks_ao as gen_masks - # Baseline - # from generate_data import gen_masks_baseline as gen_masks + if self.baseline: + from generate_data import gen_masks_baseline as gen_masks + else: + from generate_data import gen_masks_ao as gen_masks self.gen_masks = gen_masks + def decode_img_bytes(self, img_bytes_tensor, baseline=False): + import torch + + image_tensor = self.file_bytes_to_image_tensor(img_bytes_tensor) + from torchvision.transforms import v2 + + if not self.baseline: + image_tensor = torch.from_numpy(image_tensor) + image_tensor = image_tensor.permute((2, 0, 1)) + image_tensor = image_tensor.cuda() + image_tensor = v2.ToDtype(torch.float32, scale=True)(image_tensor) + return image_tensor + @modal.web_endpoint(docs=True, method="POST") async def upload_rle(self, image): def upload_rle_inner(input_bytes): @@ -220,18 +188,17 @@ def upload_rle_inner(input_bytes): masks = self.mask_generator.generate(image_tensor) return self.masks_to_rle_dict(masks) - # return self.profiler_runner(TARGET + "traces/trace.json.gz", upload_rle_inner, bytearray(await image.read())) return upload_rle_inner(bytearray(await image.read())) @modal.method() def inference_amg_rle(self, input_bytes) -> dict: - image_tensor = self.file_bytes_to_image_tensor(input_bytes) + image_tensor = self.decode_img_bytes(input_bytes) masks = self.gen_masks("amg", image_tensor, self.mask_generator) return self.masks_to_rle_dict(masks) @modal.method() def inference_amg_meta(self, input_bytes) -> dict: - image_tensor = self.file_bytes_to_image_tensor(input_bytes) + image_tensor = self.decode_img_bytes(input_bytes) masks = self.gen_masks("amg", image_tensor, self.mask_generator) rle_dict = self.masks_to_rle_dict(masks) masks = {} @@ -249,7 +216,7 @@ def inference_sps_rle(self, input_bytes, prompts) -> dict: prompts = np.array(prompts) prompts_label = np.array([1] * len(prompts)) - image_tensor = self.file_bytes_to_image_tensor(input_bytes) + image_tensor = self.decode_img_bytes(input_bytes) masks = self.gen_masks( "sps", image_tensor, @@ -267,7 +234,7 @@ def inference_mps_rle(self, input_bytes, prompts) -> dict: prompts = np.array(prompts) prompts_label = np.array([1] * len(prompts)) - image_tensor = self.file_bytes_to_image_tensor(input_bytes) + image_tensor = self.decode_img_bytes(input_bytes) masks = self.gen_masks( "mps", image_tensor, @@ -313,7 +280,7 @@ def plot_image_tensor(self, image_tensor, masks, output_format, prompts=None): @modal.method() def inference_amg(self, input_bytes, output_format="png"): - image_tensor = self.file_bytes_to_image_tensor(input_bytes) + image_tensor = self.decode_img_bytes(input_bytes) masks = self.gen_masks("amg", image_tensor, self.mask_generator) return self.plot_image_tensor(image_tensor, masks, output_format) @@ -323,7 +290,7 @@ def inference_sps(self, input_bytes, prompts, output_format="png"): prompts = np.array(prompts) prompts_label = np.array([1] * len(prompts)) - image_tensor = self.file_bytes_to_image_tensor(input_bytes) + image_tensor = self.decode_img_bytes(input_bytes) masks = self.gen_masks( "sps", image_tensor, @@ -343,7 +310,7 @@ def inference_mps(self, input_bytes, prompts, output_format="png"): prompts = np.array(prompts) prompts_label = np.array([1] * len(prompts)) - image_tensor = self.file_bytes_to_image_tensor(input_bytes) + image_tensor = self.decode_img_bytes(input_bytes) masks = self.gen_masks( "mps", image_tensor, @@ -369,6 +336,17 @@ def get_center_points(task_type, meta_path): return center_points +def timed_print(msg): + from datetime import datetime + + current_time = datetime.now() + timestamp_with_nanoseconds = ( + current_time.strftime("%Y-%m-%d %H:%M:%S.") + + f"{current_time.microsecond * 1000:09d}" + ) + print(f"{str(timestamp_with_nanoseconds)}: {msg}") + + def main( task_type, input_paths, @@ -376,11 +354,13 @@ def main( output_rle=False, output_meta=False, meta_paths=None, + baseline=False, + name=None, ): assert task_type in ["amg", "sps", "mps"] if task_type in ["sps", "mps"]: assert meta_paths is not None - input_paths = open(input_paths).read().split("\n") + input_paths = open(input_paths).read().split("\n")[:-1] for input_path in input_paths: assert Path(input_path).exists() @@ -393,7 +373,7 @@ def main( if meta_paths is not None: meta_mapping = {} - meta_paths = open(meta_paths).read().split("\n") + meta_paths = open(meta_paths).read().split("\n")[:-1] for meta_path in meta_paths: assert Path(meta_path).exists() key = Path(meta_path).name.split("_meta.json")[0] @@ -401,7 +381,10 @@ def main( meta_mapping[key] = meta_path try: - model = modal.Cls.lookup("torchao-sam-2-cli", "Model")() + if name is None: + name = "torchao-sam-2-cli" + model = modal.Cls.lookup(name, "Model") + model = model(task_type=task_type, baseline=int(baseline)) except modal.exception.NotFoundError: print( "Can't find running app. To deploy the app run the following", @@ -411,44 +394,66 @@ def main( print("modal deploy cli_on_modal.py") return - print("idx,time(s)") - for idx, (input_path) in enumerate(input_paths): + outputs = [] + output_paths = [] + timed_print(f"Queueing {len(input_paths)} tasks...") + for input_path in input_paths: key = Path(input_path).name.split(".jpg")[0] key = f"{Path(input_path).parent.name}/{key}" if meta_paths is not None: meta_path = meta_mapping[key] center_points = get_center_points(task_type, meta_path) - start = time.perf_counter() input_bytes = bytearray(open(input_path, "rb").read()) output_path = output_directory / Path(key) + output_paths.append(str(output_path)) output_path.parent.mkdir(parents=False, exist_ok=True) if output_meta: assert task_type == "amg" - output_dict = model.inference_amg_meta.remote(input_bytes) - with open(f"{output_path}_meta.json", "w") as file: - file.write(json.dumps(output_dict, indent=4)) + outputs.append(model.inference_amg_meta.remote.aio(input_bytes)) elif output_rle: if task_type == "amg": - output_dict = model.inference_amg_rle.remote(input_bytes) + outputs.append(model.inference_amg_rle.remote.aio(input_bytes)) if task_type == "sps": - output_dict = model.inference_sps_rle.remote(input_bytes, center_points) + outputs.append( + model.inference_sps_rle.remote.aio(input_bytes, center_points) + ) if task_type == "mps": - output_dict = model.inference_mps_rle.remote(input_bytes, center_points) - with open(f"{output_path}_masks.json", "w") as file: - file.write(json.dumps(output_dict, indent=4)) + outputs.append( + model.inference_mps_rle.remote.aio(input_bytes, center_points) + ) else: if task_type == "amg": - output_bytes = model.inference_amg.remote(input_bytes) + outputs.append(model.inference_amg.remote.aio(input_bytes)) if task_type == "sps": - output_bytes = model.inference_sps.remote(input_bytes, center_points) + outputs.append( + model.inference_sps.remote.aio(input_bytes, center_points) + ) if task_type == "mps": - output_bytes = model.inference_mps.remote(input_bytes, center_points) + outputs.append( + model.inference_mps.remote.aio(input_bytes, center_points) + ) + + async def run_all(outputs): + outputs = await asyncio.gather(*outputs) + return outputs + + timed_print("Awaiting tasks...") + outputs = asyncio.run(run_all(outputs)) + + timed_print("Processing task output...") + for output, output_path in zip(outputs, output_paths): + if output_meta: + with open(f"{output_path}_meta.json", "w") as file: + file.write(json.dumps(output, indent=4)) + elif output_rle: + with open(f"{output_path}_masks.json", "w") as file: + file.write(json.dumps(output, indent=4)) + else: with open(f"{output_path}_annotated.png", "wb") as file: - file.write(output_bytes) - end = time.perf_counter() - print(f"{idx},{end - start}") + file.write(output) + timed_print("Done.") if __name__ == "__main__": diff --git a/examples/sam2_amg_server/compare_rle_lists.py b/examples/sam2_amg_server/compare_rle_lists.py index 841d1d9d8e..7a1c78b846 100644 --- a/examples/sam2_amg_server/compare_rle_lists.py +++ b/examples/sam2_amg_server/compare_rle_lists.py @@ -1,10 +1,26 @@ import json from pathlib import Path +from typing import Any, Dict import fire +import numpy as np import torch -from torchao._models.sam2.utils.amg import rle_to_mask + +# from torchao._models.sam2.utils.amg import rle_to_mask +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + """ Script to calculate mIoU given two lists of rles from upload_rle endpoint @@ -20,6 +36,10 @@ def iou(mask1, mask2): return intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)) +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + def compare_masks(masks, ref_masks, order_by_area=False, verbose=False): v0_areas = [] v1_areas = [] @@ -27,8 +47,6 @@ def compare_masks(masks, ref_masks, order_by_area=False, verbose=False): v1_masks = [] for k0 in ref_masks: assert k0 in masks, f"Expected {k0} to be in return data" - from torchao._models.sam2.utils.amg import area_from_rle - v0_area = area_from_rle(ref_masks[k0]) v1_area = area_from_rle(masks[k0]) v0_areas.append(v0_area) diff --git a/examples/sam2_amg_server/modal_experiments.sh b/examples/sam2_amg_server/modal_experiments.sh index fd9411822f..2d7d8c1ab2 100755 --- a/examples/sam2_amg_server/modal_experiments.sh +++ b/examples/sam2_amg_server/modal_experiments.sh @@ -2,28 +2,45 @@ set -ex -# outputdir="/Users/cpuhrsch/blogs/tmp/sam2_amg_example_run_1" -# while IFS= read -r filepath; do -# filename=$(basename "$filepath") -# dirname=$(basename "$(dirname "$filepath")") -# mkdir -p "${outputdir}"/"${dirname}" -# echo curl -w "\"%{time_total}s\\\\n\"" -s -X POST https://cpuhrsch--torchao-sam-2-cli-model-upload-rle.modal.run -F "image=@${filepath}" -o "${outputdir}"/"${dirname}"/"${filename}.json" -# echo "${filepath}" >> cmds_input_paths -# echo "${outputdir}"/"${dirname}"/"${filename}.json" >> cmds_output_paths -# done < ~/data/sav_val_image_paths_shuf_1000 - -# time python cli_on_modal.py --task-type amg --input-paths ~/blogs/cmds_input_paths --output_directory /Users/cpuhrsch/blogs/tmp/sam2_amg_example_run_1_amg --output-rle False --meta-paths ~/blogs/cmds_meta_paths -# time python cli_on_modal.py --task-type sps --input-paths ~/blogs/cmds_input_paths --output_directory /Users/cpuhrsch/blogs/tmp/sam2_amg_example_run_1_sps --output-rle False --meta-paths ~/blogs/cmds_meta_paths -# time python cli_on_modal.py --task-type mps --input-paths ~/blogs/cmds_input_paths --output_directory /Users/cpuhrsch/blogs/tmp/sam2_amg_example_run_1_mps --output-rle False --meta-paths ~/blogs/cmds_meta_paths - -# # amg -# modal deploy cli_on_modal.py -# time python cli_on_modal.py --task-type amg --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/tmp/sam2_amg_example_run_1_amg --output-rle True --meta-paths ~/blogs/cmds_meta_paths | tee ~/blogs/amg_latencies - -# # sps -# modal deploy cli_on_modal.py -# time python cli_on_modal.py --task-type sps --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/tmp/sam2_amg_example_run_1_sps --output-rle True --meta-paths ~/blogs/cmds_meta_paths | tee ~/blogs/sps_latencies +# amg baseline +modal deploy cli_on_modal.py --name torchao-sam-2-cli-amg-baseline +mkdir -p ~/blogs/outputs/amg_baseline +time python cli_on_modal.py --task-type amg --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/outputs/amg_baseline --output-rle True --meta-paths ~/blogs/cmds_meta_paths --name torchao-sam-2-cli-amg-baseline --baseline +modal app stop torchao-sam-2-cli-amg-baseline + +# sps baseline +modal deploy cli_on_modal.py --name torchao-sam-2-cli-sps-baseline +mkdir -p ~/blogs/outputs/sps_baseline +time python cli_on_modal.py --task-type sps --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/outputs/sps_baseline --output-rle True --meta-paths ~/blogs/cmds_meta_paths --name torchao-sam-2-cli-sps-baseline --baseline +modal app stop torchao-sam-2-cli-sps-baseline + +# mps baseline +modal deploy cli_on_modal.py --name torchao-sam-2-cli-mps-baseline +mkdir -p ~/blogs/outputs/mps_baseline +time python cli_on_modal.py --task-type mps --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/outputs/mps_baseline --output-rle True --meta-paths ~/blogs/cmds_meta_paths --name torchao-sam-2-cli-mps-baseline --baseline +modal app stop torchao-sam-2-cli-mps-baseline + +# amg +modal deploy cli_on_modal.py --name torchao-sam-2-cli-amg +mkdir -p ~/blogs/outputs/amg +time python cli_on_modal.py --task-type amg --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/outputs/amg --output-rle True --meta-paths ~/blogs/cmds_meta_paths --name torchao-sam-2-cli-amg +modal app stop torchao-sam-2-cli-amg + +# sps +modal deploy cli_on_modal.py --name torchao-sam-2-cli-sps +mkdir -p ~/blogs/outputs/sps +time python cli_on_modal.py --task-type sps --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/outputs/sps --output-rle True --meta-paths ~/blogs/cmds_meta_paths --name torchao-sam-2-cli-sps +modal app stop torchao-sam-2-cli-sps # mps -modal deploy cli_on_modal.py -time python cli_on_modal.py --task-type mps --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/tmp/sam2_amg_example_run_1_mps --output-rle True --meta-paths ~/blogs/cmds_meta_paths | tee ~/blogs/mps_latencies +modal deploy cli_on_modal.py --name torchao-sam-2-cli-mps +mkdir -p ~/blogs/outputs/mps +time python cli_on_modal.py --task-type mps --input-paths ~/blogs/cmds_input_paths --output_directory ~/blogs/outputs/mps --output-rle True --meta-paths ~/blogs/cmds_meta_paths --name torchao-sam-2-cli-mps +modal app stop torchao-sam-2-cli-mps + +echo "amg vs baseline" +python compare_rle_lists.py ~/blogs/outputs/amg ~/blogs/outputs/amg_baseline --compare-folders --strict +echo "sps vs baseline" +python compare_rle_lists.py ~/blogs/outputs/sps ~/blogs/outputs/sps_baseline --compare-folders --strict +echo "mps vs baseline" +python compare_rle_lists.py ~/blogs/outputs/mps ~/blogs/outputs/mps_baseline --compare-folders --strict From bae41d174ad206be3f853414dd0055c552fde0fe Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 10 Feb 2025 12:05:01 -0800 Subject: [PATCH 065/115] mx: add ceil and RNE rounding modes to the cast from fp32 to e8m0 (#1643) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 18 +++++- torchao/prototype/mx_formats/mx_tensor.py | 71 ++++++++++++++++++--- 2 files changed, 76 insertions(+), 13 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 21cb49c064..ad718beb9c 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -18,6 +18,7 @@ from torchao.prototype.mx_formats.mx_tensor import ( E8M0_EXPONENT_NAN_VAL, MXTensor, + ScaleCalculationMode, to_dtype, ) from torchao.quantization.utils import compute_error @@ -47,8 +48,10 @@ def run_before_and_after_tests(): torch._dynamo.reset() -def _test_mx(data_hp, elem_dtype, block_size): - data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size) +def _test_mx( + data_hp, elem_dtype, block_size, scale_calculation_mode=ScaleCalculationMode.FLOOR +): + data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size, scale_calculation_mode) data_mx_dq = data_mx.to_dtype(data_hp.dtype) def assert_sqnr_gt_threshold(orig, new, threshold): @@ -61,7 +64,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold): assert sqnr >= threshold if elem_dtype is torch.float8_e4m3fn: - assert_sqnr_gt_threshold(data_hp, data_mx_dq, 20.0) + assert_sqnr_gt_threshold(data_hp, data_mx_dq, 18.0) else: assert_sqnr_gt_threshold(data_hp, data_mx_dq, 14.0) @@ -74,6 +77,15 @@ def test_hello_world(elem_dtype): _test_mx(data, elem_dtype, block_size) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode]) +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_realistic_numerics(elem_dtype, scale_calculation_mode): + data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + block_size = 32 + _test_mx(data, elem_dtype, block_size, scale_calculation_mode) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_all_zeros(elem_dtype): diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 8eeeaf8bfd..801f29ac3c 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -16,6 +16,7 @@ * Zeros: N/A """ +from enum import Enum, auto from typing import Dict, Union import torch @@ -53,11 +54,38 @@ unpack_uint4, ) +# TODO(later): read from somewhere else? +SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23 +EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1 +EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3 +EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2 +EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3 +EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2 + + +class ScaleCalculationMode(Enum): + """ + Enum representing the different methods for calculating MX block scaling. + There are three methods available: + FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp). + It result in overflow issues for large values and bad for gradient quantization. + CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor. + It uses X = 2^ceil(log2(max_abs(v))-max_exp). + EVEN: This method is a trade-off between Option 1 and Option 2. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)). + It provides better accuracy for MX4 training compared to FLOOR and CEIL. + By default, we use the EVEN method for better accuracy. + """ + + FLOOR = auto() + CEIL = auto() + EVEN = auto() + def to_mx( data_hp: torch.Tensor, elem_dtype: Union[torch.dtype, str], block_size: int, + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, ): """ Takes a high precision tensor and converts to MX scale and raw data, in @@ -88,25 +116,45 @@ def to_mx( # where the values are zero. eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype) - # Find largest power of 2 less than or equal to max_abs. - largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps)) - # Set X to be the largest power-of-two less than or equal to # max_abs(v), divided by the largest power of two representable - # in the element data type + # in the element data type, and get the mbits at the same time if elem_dtype == torch.float8_e4m3fn: target_max_pow2 = F8E4M3_MAX_POW2 + mbits = MBITS_F8_E4M3 elif elem_dtype == torch.float8_e5m2: target_max_pow2 = F8E5M2_MAX_POW2 + mbits = MBITS_F8_E5M2 elif elem_dtype == DTYPE_FP6_E2M3: target_max_pow2 = F6_E2M3_MAX_POW2 + mbits = MBITS_F6_E2M3 elif elem_dtype == DTYPE_FP6_E3M2: target_max_pow2 = F6_E3M2_MAX_POW2 + mbits = MBITS_F6_E3M2 elif elem_dtype == DTYPE_FP4: target_max_pow2 = F4_E2M1_MAX_POW2 + mbits = MBITS_F4_E2M1 else: - raise AssertionError("unsupported") - scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2 + raise AssertionError("unsupported element dtype") + + # rounding before calculating the largest power of 2 + # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)) + if scaling_mode == ScaleCalculationMode.EVEN: + nan_mask = torch.isnan(max_abs) + max_abs = max_abs.to(torch.float32).view(torch.int32) + val_to_add = 1 << (MBITS_F32 - mbits - 1) + mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32 + max_abs = (max_abs + val_to_add) & mask + max_abs = max_abs.view(torch.float32) + max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device) + + # Calculate the scale for different modes + if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN): + scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2 + elif scaling_mode == ScaleCalculationMode.CEIL: + scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2 + else: + raise AssertionError("unsupported scaling calculation mode") # Clamp to exponents that can be represented in e8m0 scale_e8m0_unbiased = torch.clamp( @@ -270,15 +318,17 @@ class ToMXConstrFunc(torch.autograd.Function): """ @staticmethod - def forward(ctx, data_hp, elem_dtype, block_size): - scale_e8m0_biased, data_lp = to_mx(data_hp, elem_dtype, block_size) + def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode): + scale_e8m0_biased, data_lp = to_mx( + data_hp, elem_dtype, block_size, scaling_mode + ) return MXTensor( scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype ) @staticmethod def backward(ctx, g): - return g, None, None + return g, None, None, None @torch._dynamo.allow_in_graph @@ -392,8 +442,9 @@ def to_mx( data_hp: torch.Tensor, elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, ): - return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size) + return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode) def __tensor_flatten__(self): ctx = { From 32a51eca14257bbaafd3671a5349189e30c65e2b Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 10 Feb 2025 12:08:44 -0800 Subject: [PATCH 066/115] Support power of 2 scaling factors in float8 training and use e4m3 everywhere (#1670) --- test/float8/test_base.py | 6 ++- test/float8/test_compile.py | 20 +++++--- test/float8/test_float8_utils.py | 65 ++++++++++++++++++++++++++ torchao/float8/config.py | 21 +++++++-- torchao/float8/float8_linear.py | 6 +++ torchao/float8/float8_scaling_utils.py | 4 ++ torchao/float8/float8_utils.py | 44 ++++++++++++----- 7 files changed, 145 insertions(+), 21 deletions(-) create mode 100644 test/float8/test_float8_utils.py diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 3e894c02b9..b537c7ab9f 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -164,7 +164,10 @@ def test_transpose(self): @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)]) @pytest.mark.parametrize("axiswise_dim", [0, -1]) - def test_axiswise_dynamic_cast(self, shape, axiswise_dim): + @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False]) + def test_axiswise_dynamic_cast( + self, shape, axiswise_dim, round_scales_to_power_of_2 + ): a = torch.randn(*shape, dtype=torch.bfloat16) linear_mm_config = LinearMMConfig() a_fp8 = hp_tensor_to_float8_dynamic( @@ -173,6 +176,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim): linear_mm_config, scaling_granularity=ScalingGranularity.AXISWISE, axiswise_dim=axiswise_dim, + round_scales_to_power_of_2=round_scales_to_power_of_2, ) a_dq = a_fp8.to_original_precision() sqnr = compute_error(a, a_dq) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index c42ab8ee77..d9c71f7395 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -45,11 +45,7 @@ hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) -from torchao.float8.float8_tensor import ( - GemmInputRole, - LinearMMConfig, - ScaledMMConfig, -) +from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig from torchao.float8.float8_utils import config_has_stateful_scaling from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -420,13 +416,23 @@ def test_sync_amax_func_cuda_graph_success(): torch.float16, ], ) -def test_dynamic_scale_numeric_parity(dtype: torch.dtype): +@pytest.mark.parametrize( + "round_scales_to_power_of_2", + [ + True, + False, + ], +) +def test_dynamic_scale_numeric_parity( + dtype: torch.dtype, round_scales_to_power_of_2: bool +): scaling_type_weight = ScalingType.DYNAMIC torch.manual_seed(42) hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype) hp_tensor2 = hp_tensor1.detach().clone() float8_config = Float8LinearConfig( cast_config_weight=CastConfig(scaling_type=scaling_type_weight), + round_scales_to_power_of_2=round_scales_to_power_of_2, ) linear_mm_config = LinearMMConfig( # output @@ -456,6 +462,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, ) torch._dynamo.reset() float8_compile = torch.compile(hp_tensor_to_float8_dynamic)( @@ -463,6 +470,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, + round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2, ) assert torch.equal(float8_eager._scale, float8_compile._scale) assert torch.equal(float8_eager._data, float8_compile._data) diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py new file mode 100644 index 0000000000..ca9f21dde1 --- /dev/null +++ b/test/float8/test_float8_utils.py @@ -0,0 +1,65 @@ +import unittest + +import pytest +import torch + +from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +if not TORCH_VERSION_AT_LEAST_2_5: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +# source for notable single-precision cases: +# https://en.wikipedia.org/wiki/Single-precision_floating-point_format +@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") +@pytest.mark.parametrize( + "test_case", + [ + # ("test_case_name", input, expected result) + ("one", 1.0, 1.0), + ("inf", float("inf"), float("inf")), + ("nan", float("nan"), float("nan")), + ("smallest positive subnormal number", 2**-126 * 2**-23, 2**-126 * 2**-23), + ("largest normal number", 2**127 * (2 - 2**-23), float("inf")), + ("smallest positive normal number", 2**-126, 2**-126), + ("largest number less than one", 1.0 - 2**-24, 0.5), + ("smallest number larger than one", 1.0 + 2**-23, 1.0), + # TODO(danielvegamyhre): debug why creating a tensor with largest + # subnormal value in CI env for pytorch 2.5.1 truncates the value to 0. + # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]), + ], +) +def test_round_scale_down_to_power_of_2_valid_inputs( + test_case: dict, +): + test_case_name, input, expected_result = test_case + input_tensor, expected_tensor = ( + torch.tensor(input, dtype=torch.float32).cuda(), + torch.tensor(expected_result, dtype=torch.float32).cuda(), + ) + result = _round_scale_down_to_power_of_2(input_tensor) + + assert ( + torch.equal(result, expected_tensor) + or (result.isnan() and expected_tensor.isnan()) + ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}" + + +@pytest.mark.parametrize( + "invalid_dtype", + [ + torch.bfloat16, + torch.float16, + torch.float64, + torch.int8, + torch.uint8, + torch.int32, + torch.uint32, + torch.int64, + ], +) +def test_non_float32_input(invalid_dtype: torch.dtype): + non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype) + with pytest.raises(AssertionError, match="scale must be float32 tensor"): + _round_scale_down_to_power_of_2(non_float32_tensor) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index fb306e0fb7..b971ff31b0 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -234,6 +234,13 @@ class Float8LinearConfig: # tests so that the warning does not spam the CI stdout. force_recompute_fp8_weight_in_bwd: bool = False + # If this option is enabled, the scaling factor used for float8 quantization + # will be rounded down to the nearest power of 2. This has been shown to help + # reduce quantization error by avoiding rounding errors when multiplying/dividing + # by the scaling factor, as well as ensuring large values are quantized to the + # same value in the forward pass as the backward passes. + round_scales_to_power_of_2: bool = False + def __post_init__(self): # Populate the additional cast overrides, if the user did not specify them # Note: this hacks around the frozen-ness of this dataclass @@ -338,14 +345,22 @@ def recipe_name_to_linear_config( elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE: # dynamic axiswise scaling with the CUTLASS rowwise kernel - cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_i = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_w = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) return Float8LinearConfig( cast_config_input=cc_i, cast_config_weight=cc_w, cast_config_grad_output=cc_go, + # enable power of 2 scaling factors by default for row-wise scaling + round_scales_to_power_of_2=True, ) elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 6b3c0f06df..0bc2690bc5 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -96,6 +96,7 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_input.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -112,6 +113,7 @@ def forward( axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_weight.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) # the reshapes are needed in order to make the shapes compatible with @@ -151,6 +153,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_grad_output.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(weight_hp_t): @@ -181,6 +184,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( -1, c.cast_config_weight_for_grad_input.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) grad_input = torch.mm( @@ -216,6 +220,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_grad_output_for_grad_weight.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) if tensor_already_casted_to_fp8(input_hp_reshaped): @@ -233,6 +238,7 @@ def backward(ctx, grad_output): axiswise_dim=get_maybe_axiswise_dim( 0, c.cast_config_input_for_grad_weight.scaling_granularity ), + round_scales_to_power_of_2=c.round_scales_to_power_of_2, ) grad_weight = torch.mm( diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index 0c27e4f3fc..b96c7a9b58 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -27,6 +27,7 @@ ) +# TODO(danielvegamyhre): refactor to accept Float8LinearConfig directly def hp_tensor_to_float8_dynamic( hp_tensor: torch.Tensor, float8_dtype: torch.dtype, @@ -36,6 +37,7 @@ def hp_tensor_to_float8_dynamic( device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + round_scales_to_power_of_2: bool = False, ) -> Float8Tensor: """ Given a high precision tensor `hp_tensor`, @@ -51,6 +53,7 @@ def hp_tensor_to_float8_dynamic( the 3 fwd/bwd gemms of linear scaling_granularity: Defines the scaling granularity axiswise_dim: if axiswise granularity is used, defines the dim to scale across + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. """ scale = tensor_to_scale( hp_tensor, @@ -59,6 +62,7 @@ def hp_tensor_to_float8_dynamic( device_mesh, scaling_granularity, axiswise_dim, + round_scales_to_power_of_2, ) return hp_tensor_and_scale_to_float8( hp_tensor, diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 6a93a612fa..926b97edb8 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -10,11 +10,7 @@ import torch.distributed as dist from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import ( - Float8LinearConfig, - ScalingGranularity, - ScalingType, -) +from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -33,21 +29,28 @@ @torch.no_grad() -def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): +def amax_to_scale( + amax: torch.Tensor, + float8_dtype: torch.dtype, + round_scales_to_power_of_2: bool = False, +): """Converts the amax value of a tensor to the fp8 scale. Args: amax: The amax value of the tensor. float8_dtype: The float8 dtype. + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. """ # torch.compile and eager show different numerics for 1.0 / float32, # upcast to float64 to ensure same numeric between compile and eager amax = amax.to(torch.float64) if float8_dtype in FP8_TYPES: res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) + res = res.to(torch.float32) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") - - return res.to(torch.float32) + if round_scales_to_power_of_2: + res = _round_scale_down_to_power_of_2(res) + return res @torch.no_grad() @@ -119,21 +122,35 @@ def tensor_to_amax( @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, + hp_tensor: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False, device_mesh=None, scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, axiswise_dim: Optional[int] = None, + round_scales_to_power_of_2: bool = False, ) -> torch.Tensor: + """ + Compute scaling factor for the given high precision tensor. + + Args: + hp_tensor: high precision tensor + float8_dtype: the float8 dtype to use + reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks + scaling_granularity: Defines the scaling granularity + axiswise_dim: if axiswise granularity is used, defines the dim to scale across + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. + """ amax = tensor_to_amax( - x, + hp_tensor, reduce_amax, device_mesh, scaling_granularity, axiswise_dim, ) - return amax_to_scale(amax, float8_dtype) + return amax_to_scale( + amax, float8_dtype, round_scales_to_power_of_2=round_scales_to_power_of_2 + ) def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): @@ -266,3 +283,8 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC ) + + +def _round_scale_down_to_power_of_2(scale: torch.Tensor): + assert scale.dtype == torch.float32, "scale must be float32 tensor" + return torch.exp2(torch.floor(torch.log2(scale))) From 999b16db6380cb7dc08ba5779f230206471b3120 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:19:25 -0800 Subject: [PATCH 067/115] Add third_party to exclude (#1692) --- ruff.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ruff.toml b/ruff.toml index a4ac551476..10969fed6b 100644 --- a/ruff.toml +++ b/ruff.toml @@ -2,3 +2,9 @@ # Add linting rules here lint.select = ["F", "I"] lint.ignore = ["E731"] + + +# Exclude third-party modules +exclude = [ + "third_party/*", +] From d99785c0fdaa1dbfdbaf57923326edf2b8a7f1f8 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Tue, 11 Feb 2025 10:44:56 -0800 Subject: [PATCH 068/115] Update float8nocompile readme (#1693) --- torchao/prototype/float8nocompile/README.md | 73 +++++++++++++++++- .../float8nocompile_loss_curves.png | Bin 0 -> 94660 bytes 2 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 torchao/prototype/float8nocompile/float8nocompile_loss_curves.png diff --git a/torchao/prototype/float8nocompile/README.md b/torchao/prototype/float8nocompile/README.md index 87ced9fddc..4723ff9e60 100644 --- a/torchao/prototype/float8nocompile/README.md +++ b/torchao/prototype/float8nocompile/README.md @@ -1,3 +1,72 @@ -# Work in progress +# float8nocompile -A prototype version of Float8Linear which is performant without `torch.compile`. + +A prototype API for high performance eager mode float8 training that uses handwritten Triton kernels for quantization. + +### Usage + +Prepare your model for high performance eager mode float8 training with a single conversion function: `convert_to_float8_nocompile_training` ([source](https://github.com/pytorch/ao/blob/32a51eca14257bbaafd3671a5349189e30c65e2b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py#L24)). + +This function will replace nn.Linear layers with Float8NoCompileLinear layers in-place, which uses **dynamic, tensorwise scaling** +to perform all matmuls in the linear layer forward and backward pass as FP8 GEMMs. + +**Example**: + +```python +from torchao.prototype.float8nocompile.float8nocompile_linear_utils import ( + convert_to_float8_nocompile_training, +) + +# define your model, data loaders, etc +... + +# convert specified `torch.nn.Linear` modules to `Float8Linear` +convert_to_float8_nocompile_training(model) + +# training loop +for i in range(num_epochs): + ... +``` + +### Performance benchmarks + +Performance benchmarking was done via [experimental integration into torchtitan](https://github.com/pytorch/torchtitan/pull/778). + +The results indicate a solid 6-10% tokens/sec speedup with relatively flat memory (+/- 1% peak memory) compared the bf16 eager baseline. + +# Performance Comparison of Different Configurations on 8 H100s + +## No AC (seq len 4096) - 8 H100s + +| Configuration | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ | +|-------------------------------------------------|------------|------------------|--------------|---------------| +| bfloat16, eager | 5339.0 | 53.12 | 0% | 0.00% | +| float8nocompile prototype | 5871.4 | 52.7 | 9.97% | -0.79% | +| float8 + torch.compile | 6667.6 | 46.64 | 24.88% | -12.20% | + +--- + +## Selective per layer AC (AC every 2nd layer, seq len 4096) - 8 H100s + +| Configuration | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ | +|-------------------------------------------------|------------|------------------|--------------|---------------| +| bfloat16, eager | 4882.4 | 40.6 | 0% | 0.00% | +| float8nocompile prototype | 5302.0 | 40.97 | 8.59% | 0.91% | +| float8 + torch.compile | 6199.6 | 37.38 | 26.98% | -7.93% | + +--- + +## Full AC (seq len 4096) - 8 H100s + +| Configuration | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ | +|-------------------------------------------------|------------|------------------|--------------|---------------| +| bfloat16, eager | 4502.0 | 28.07 | 0% | 0.00% | +| float8nocompile prototype | 4773.4 | 28.07 | 6.03% | 0.00% | +| float8 + torch.compile | 5775.2 | 28.03 | 28.28% | -0.14% | + + +## Numerical accuracy + +Numerical accuracy has been verified via unit tests as well as manually verifying that the training loss curves maintain fidelity with the loss curves for bf16 eager and production float8 + torch.compile: + +![loss curves](float8nocompile_loss_curves.png "Loss curves") diff --git a/torchao/prototype/float8nocompile/float8nocompile_loss_curves.png b/torchao/prototype/float8nocompile/float8nocompile_loss_curves.png new file mode 100644 index 0000000000000000000000000000000000000000..a136512b9d15b78c7715703dbb3607630e590664 GIT binary patch literal 94660 zcmZ^K1yEc~*Cy`n4uiY9yK5l0I}Gmb8r*`r1cE!k-Q6Kr&>+EGck;g9FST1cRkv=> zZ8@jAPxq02I$Bj(1{r|>0SpWbSx#2+6Broy>E8zq2K43PrnwOGN6SV+LRC&ef=t!L z(bC4w0t}2c$v#m;VL%mYNX51Utr^|@OqHnQD}A!&7qd9B5Q*}S&3Vo;aST$MFCHG) zP_Q!2*O;^hP(xz|s#k=(h+^ebTPxq~mtNX7NF!ahJ1!v5)vP9QzSQsg&tv}ZOkcQmzN_Of^STLnzeiyxG$5E*50m%&g3;Z6bogsCo(D~|y1o!(Wann)<*jU^YzJ)Cc4SlU zPv%nW!0|%F2Z3ITOb`|AX_L^dP(WJQZ-h$H_>CJXe`T*DkjT7g-l|6~qY+L9nf< z?(aHy)Q}SL!yP^SZ&{xfFb>4Wdmj+OLvn@P*A0d58<3Ha!*>8+d1C8x&T)H_bL;E2 zKU?e@?@m|rs8cdyC2mRIq?0SUFM6!vMF)kt#mwRZ6@KMZr9)VJM1Z)RBm}3IAB`>6 zW{kM$gsg@KQ=-eB(p6AU;Jukz)cEc`<$Ke$FQD;SJWzZeZq;M_@wWY8|4lg;p@fhN zIGto?Rt(|&ge({MlpgrQe=OZEy1BI#*3?Lm$+7!NY4;_#oS~BlJ}W}sY5sfmCv0Rr z9#eiMb3sriF_cgz{&o^NSa#9h-``_*5MIFJEF)Lg-=~j3DyqGs<0rS{mfg=*$JWo4 znoN&8mI!F(6W{=e#iR^x5!;_RxS@EK^kx)KkVj>=7UIXuUC#8!6$QZ)B@`znY=xGE3Yvwi(wOLz-8$Dm`VUUCrQ8#Y(Q@! z(^u5?(T~T3k5y)ac5_t0OVQcyzCJ!xjg2>1du1^IPbO}Tu6XCniVvS_9zj7nnFNw% z|Cjr<-@nW4*BXZI&(<&ryp22Ycv-$!Dv|HtNbLB?*V*d)?kC8nB@CNg&mE_tawVrR z0P@bW#-|Or#XW>ZW{E60SG^nT*^f`0Ts{%EKbv==8;q*B@yM0=c zF^7$%gareCbru?Uy1A{|w{CK0s(ef4IxI{K{7j`LY{skpbY$r2>RKK)w?31@=Q?%y z4biOTA&Dthj<94RSTn&4$!8>&$E- znKB!3)X3Yzwqu9KAB6mU{PU0Zi;uG(M|s(bgy-t__@RLd*;`m!s3eu|gPxTROk*?J z0XB+&hZGfI8#6OAKrO4AIHY6!M(3E1lM^dkyTfL?@jefsHKf5g< zj{}l-j5B$&&NrW~n)6@<8+~3Bq;Te+gbrSg;powhZf2iiKlXqzuNY~xCFfd-&vs@J zkRbeQJ^C_AWKkkqpd4*Zy#2C^+k-d!+srSTj9=A{{7@w@y%_26kWjcT)!^2 zAj$K3e{2j8v)-f&x5koIkZyz$PHR~(gys=X&O8Pj`oyL&Yu=E-2NWS< zE9@^km=+6f&_>b0+b+hQ_AE=Oe12)dF!DaY_WgL-Cfz#d0uEI?8z7m!-x+;y#nNC* z+YxU^qh$76hO?~7h}^54cC1=eG)?+_r?2^4yER?HRwDk2O3ZKmajWuSZZ(rKU1=FnThd0m$X4c;aBFHCB731k;Da z55D-p8Ie9(cQBA_1PR9STL1$*b!21-_kC5*&jLP98&jmgtCwZ(qNjt-@$Qm-@$T;2 zOsB_m(yd;tB*X)`)lAhUcrXL(SVT=gtmYv%MDb-xu<&3GUA*{hZM*rUiNpQ_n1ze~zjP8mx z&^p#y?6GeSIkc+Te8AU@4wJ@QRGHpP_rlFn_-;{MOu34Uj};D}FPk!W@M1NSi@Yf| zs#Mz7cfq<`?U;Z)EUrboEh@YVPrZ~OX6RNqUmBTXJuz5FfcVOB9&rT;gBq4%KONyj zhT4SO-6t?eg@v%B!`65u{O(19i1hHtt;?g!l9^(d`Y6nMgA(sK+aRZP4T;~b+epdy zYE~llZmgdqJ#kOilD?-#*EUCL@(EmPB+0iqo)w-UVxm>ftm*a8&NcC(w+|8!D766> z>NS*y?!z~#$ElHsT$L1hD=8rXGamtiwz;|RLl#dk;c@s;G7eRTB0da%`&aB)>>34b z{8vA&h>had#2)eYX>rO z8JsVK-ryQGte*kIN)9Ah4efLJgF|Fx4NM|h2u`s!^)CBT(in= zrXUFV7E+9+ieKo=jtv3Pe|6@X|Ds>MTG2+Uow!5nxq%*3KqVdibi1Ge+mi?d1q;3z z@bV{^v5dZC30}7JA#=u$8jUUt{z@JEnxyPLXo|-g2@RIs!@Tq8+*e$>7)%u(BR>=t zHR1tqo$c7Y=J4a-D;etHOCrsc-*IU=M6lX1W?;K3nf>VQd@q&n%g-&6-DoxFXeB@R zi^VbZ`&ZojR!CSBsQ%zf0AIV->|hdc5vAY0jHN7r2%g>9#t&Ea^%6+rWEzq{a>Zlq z)&K)9m|%(-E0LGA+0h`{N0j#t{*JSyhi?8vkrm@$eRGgqbQwj%>#U?Zr(IB)1B2of z!j*=yU0@;+4w)%>;x>@tV7*zwB^lpiT_C3GxYjgu-gU7>Cc#(-H+^wh%8`bRoQP}a z1u&L01PGB?%pWi7IUsMBfe*CG`f@T8u_wwzVTQ@(L^Euz3PM`VCrhI^|2X~Z3Ela- zU!q1EM%?&$fD11KwrB6Aun!<`KVCl@C9dftibH^PI~pb(Zcdgv3cQUC8_wHRF)ZwJ znh|;k6CCN-kED4J!Y2qW!4AM9j1A~z3J@4?!I()Tu#w(sHxvd(eGO3>jP%2}xz%`N zgpN?lBTE@ge-68rX#XT_))^x8;`QSp#zjV$tVfkbbvh}nGcx94kay^D_;iT>#^A*1 zq>Rw9^XkjY;<($8gaH|a=KEkKzek06P1>IBim^T|j$K_K>mpyO>&YjAMW>&p2kG7^ zTssn!8t1&drA1F5`|-^QwV{wqbuJgnB{SYTBJi3f5Mvm&oUFbr{&IR@nvZoxU#-;5LzdoJeHQ^U=3O4PGxlw@J442sQj3OAC?=AI%70f3bOVfNaDd zjsX}n@?$ZtNL1n$5ucNUGoHX3o`GAGT{VL%Ct~8`&!=Rkp|B|HAPxh={lcl}8bU=V zAEvfnrwSno>w^^sogbK!nz0xNQ1ahzzd@=69%uzv%v!%6I=)j+9+!%j^gq2GiI5V2 zqnTKf{n@e@>Q6J*XIQ{qc}9_Z3r6+peSgaTSd;Zc;7wyRm&i&wvgw>tziAD8X_XF* zmq*o4%67SD=(%Pnbe6Puk5=*Pe?KsKr>fGG^7gU$P%z;=ISM#G5{ERV-ggg=h3_CE zP23(`sYh4IGD<*;qV`7Azfv$G@U!_HbmhXkKdh!)E;5?3C_{_6Z#3>f25#a1N=D-y zB~5xWoxl1ZSq^Lc`pYsz?6M6mFbZ*Rv8`hKq*+*wsNw@Zx16}BZP{1h2??LB_7ay! z{E6y9^{7B{S6r8=PW6Tw$c}<7HK~|Lmj;J_eiZnABypfTTxj|pZuAKK@;D+wLQfY3 zhY{Y}edOCsV=l{1OMpUvJ8`I{V-nkZbC`So5%Bv%M**l%cOoeuekxWOlj+K@Kgu`? zHzrzT^jvC0uwOd=ngDVFR7c0OHTpSu$v8N1_QRn7YMj0Vcp}#nA{e8@f4+|A)LeM> zylmbxZC^9}B1U*4+HgT{^eunN4lf}~?xzFIX)2>%i?4)H_=(GAP?x?cOH z>m?0*DvTsymyNru$T9M}8dCKN+@%R1A0%uD>*ZU>27kiXLg$K&ClWG!=zPu-UF2jh zU!;8QTweKgj$`}+jQc|J^Bfk2LXhS_8FFldsP?>W;QRpX; zj~kJ`g4hTYT|ht|!Myj+pIbQWtOWPlZ=2iUyvZ<$=vQUQPtUa<&!W;m&06@hg_+mM znXuakq$h5dLBG$r-Rjs-nivEGBnX2(+x`#XoIAw()TnGU*Pm;B(z*`0YU=}EM*@+Y z5we`Av9Iw5BaL57~gh0jWd_|c)Ory5!IDZ!;9Bxa-Xm3uULophr<8#>mAq~;Ztc;i9`87lJF@p( zbFKd$?Aj?#|5`5_6x@9pqC4tw@Hmg*+hH0ToX8HT*U3d#GsigX&UiDW} z;X#KoL`eyGP_N{Azfm+N@+x#W9pF;Dg2A(BXQFww%VN^E(A(XT5}H0*9A8>k&wcAZ zio{qG(U@!qs6bXy%w;0gA=A7oe`J+)A?No;rB`@F#wdZOpt{5q&UUBwkX(NMam`vC zVuApT+uQQQ8cypdn#XxUiQ>wxMxp(KjC?o-TZx6SL4WfjQ5+`-R2&~L()kc}@e6p( z+j2bL0k#lYWz^J@g%K3mnhiACc7%BOJS*X0@+%8Fwt6 zb?Hlu;wLLJPg#a}@%bZ9ju-J4Y^Ith?N3f~I>=EeiG9-dU!>w$7+OT$5I-ucDiqB4 zN&&LOcX{xgcRd989 zF_=yPLeWGj&6>`U);4LNY3Em8B8>tx6!u-RPowxfZU#3<-j@nPa^#ldCs zBl7k%izgo6anUWXlMrY=KnlwsKi96plXk=7GBx;%<@sRp0Y3Y-pfJjBM>G^h|0$02 zvas;Q-RQwx%vUIlPu7gZvTBxcxZ|uADMv9> z+aZZud*K?waB{)F$RLOMd;O6pqBL%0EW@Jk6iTv_ueed#dn?>B*Cj2+=%vtnFN@`4 zmT9or&7YTy-j)vf-9Nh+en>-M{h-ZhI&>P1i_LMyflgn!N?Q5*;N?K*IO!`}gvtumc3c;jUxGBkgHs8` zI#&i;-jpj`iZl{7*J{#7(V^}e(=I8ochIB81Yy3Cd;BUHEHXxocRH{1nyJ3r1gu4Z z7&O{g76~oA%QaaV#8U#BLhP4ORHDIr4s7l?R02(gwP0LiPl63M$%tHRU%M>4;E!Y~eUwRV~5rkEiB1&J+LWUTg<%sHPID)TdJLTo%#X;cH& zw!u8FhHymH{^xPcv%;)6B2Hy%r^FST)a7i7fJM|;wG`M2+);sgPv`o&d>$QGvPuLn zb8!pEoq87Nq!6Z3L!YQi@7&&`LO;3Quul3Q@6JcaM4n?N+1$?=oGJZ3>78MlSYX) ziW9H1YmEEDRevLVD^&!CMx^(W4I@DqrrUmAZjFHGTjkqZR-CoslWKQL0uX3R%^wIk zQBKMp`bGAWG@F`Ba-U0d6Cp_?N{>oCqr>2*-M;MJ811Lc8uy`Ijm*)K>DRnGzw~P} z*(h)I&UJ`;67St0ajDWKo&Jg>MUli5JVtFm!|g8=!waayk*bNZxSilKi9Z-Q%e?@p zW4Mp7=wc?{_*Bw1C9fg(o;~6+YU7s}8 z9|1LZD<(00yh6@Jc8lYmn2zM1h^q_CyhF|&mD&fJ0LV^DUOTWd?TNK}jfvY|GpcO{ zw_A?5I}VohAao-0SgyNJu1Vvs#kYd70?tMP7zg8A@*IwkNWe#Lsnu`6!KbyVn28mS9q zk99EsGc!BY_F%zHgLlvkT0voJxGd5v1xljW>GGiS?$h&GgkZ4a=l5JtqdPQ4w1j_D zb>@j<0n3O+etqwt!06UHD9*--MvaeFoq`YtBX;Kr9Mr8-04aZ=*mzb;K;iG{x7kP$ zETnJNimEjL);zVb6zfsx{_AXxcvrzT(tNc?GP}_THd%Tnb_TQ0XH}4DqC~MkY2Y~? z_4coyO=oSeg}bv=){?9iNJWQyU052gcy$HACfIR-vCoe%nep=ZNVo?j7@?!R*;kyL zLb4$bB+r8qfNCCxRlMOBZ_-bLwC7O-rit8Eb{M0WiRdNa!zu*Huu-@Hc#Z?&q?u&; z=CCTP{e^!nFLY5bioZcM6bods?aGjNQ|`#;n3hHkw}adlw4)inXO6<_nDF?=qBIJ^ zwnUT?wqJ?qnE}{{lx@?Ts2Gl{1!@DlB5`Zhqq13M84%1n3E;bqZ_AF*A!Q~*Jc+q} zvP1yyYYJ%_yo)3wQ!^<2CzJpy`k&3LO8p#lfQzAV)~}Vc0xSDe^Gt_y*F5ihUbu?s z3q#k^I0;X-*M?H8=9D!hKt;V>i66#o5f5Zv>gr;!;J1vJaqHu3*a`V|I~u*MDqhJ6 zkHMB#T-Z>Sj(|`3V~yUxWwne)Hu`f|!w=Eft@0*-`KxIE_9gN*+He*=MVZUsB^<6$JqEKZ7joyou*PZa<=aTSc2h7_>N$W9L^6A9Jg>k-WEb1~(_1k;w>ZDIG&O-{=SN+WP1IL$XlXQU zBCW{5HdJZrQZ(u~305+{-7tHwb>{U45DSj5IU(G>72C@-dW5e-O_fs$DG*DqWOmnx zaa5eoB9EaWaKeeGI(5HM;xxu&@+d@R%^;lB#ZtJf3047#4*Spq+;090Qt^U`=q8Dz zQ8LB6*Me8D3ezRE-=h-`k$%j2K}rzF-)5j%J{D~x)|SisklB=~W!o6M8FIRr z&eiBu^cM#fhZLBWTo23RhcdSJj9Fu0(EF1Q+3ErtfZ|$A@Z%cT?M~&g6>Y!!|13T- znDuMuVlg6xz(vYqPr;cH_Lqr)@j;+P6Qb7Gpv@u=1mDX^;c#$X*HJV5`W&ht1<&`X zpu(7L2e0!=c%Alw0`_Xy@$7olcf?S8FkV)HvhPx2h0k23vuMwMD~=tVxh)vkr0_@b zW2cH>#g6whm&;|ggZb{*#h;(8u=45jMj)Ac`$AtD0DB(ap5vC+ZVRJdaxXrO z-~zq1_HDfudHbjwjUhWZyp7X56~xI9!z>AWQ)I+oXWlZ*t&Sz3c_@j zbPd=e+`@}+?aePliAui+LcLCciCvh(D}4>(LJ8cUK{A&LC?gMHGB_=v(Eb$}>P4pI zz7>B>^+dEMu@S}L<;+4d$W$hyjjT9h4HXJ|ki`-27V}j2Y>TfVnXL*JOd zBE0$-siE$<$oD|i^HFuqRKC&3!DzQ4Er)vj6j8>NL&Vqcije9l#y->}nXs%_f}sk3 znyqhjYC3;oAQrzr#by%pdJX%!F};J%2EAwt7Q!(H_5{;N&9mv5uevI>qG!I(JZ979 z9G`3T_*QRUy;!Hl>WGIWLP_eLmM;QsXN@tuMPscKZgp8v;tjx%6Myg`kX%(#&EsjN zau}*aO8FW3TEgu}kiM$^KpR;791lZ&4>g#Lkl88P%8HGT&AqbZvIL6#;oP8H%+m&E z0Q}%mPXC6_;0TzIAvp~R9F4V4M?L?>()QbK_aErJg$yGv=(qQZkPa%t-#Ri(SgJ}^ zQ2!K@K9v*uKm9gGh1e&YZ~g99=q^oeISz=^WBsNW22)H=NFdMYF6kw8be8Mat`>3-;P zbTKQ=$m}FOyLS@!1kC&0T6B$KrD3Ryd_?2?mYhRPkcvIs1fv99+#5J*JG; zI{`*DjY6VO?Bkl{!=pj6G z^9*lM85Vnhc5LTt3EM^yhB)Fk9&q;f?f74V z9NK%Mfpe%fSnGKE4_|(%G}~Pnh#0|w|AOr)DHF^wNVX?_tuK_qyE_aK z3JN3Np4EdnQe_zwlnWSCwnjc6NU%0w)B`UUc2jR%Ld+0rnerns}}J8~aRsisH8ODFN4 zB;nokZDxPEOM^>XIp*ts4+7vIQT@PW@C?}NO~0i^E{Cm4DX_cDwkmGi9)2Yh)y`K73T!o(7cz4|_j`Zi1HVT=@?Y6REDS_Xr&pf0sMnjnODz%)#i zmtHAjMr|{3i764fPnAj>(NISR_^5**Wab{X!q`%LrTkhYr;3puD+*H>vX2DroR6el z8*WDp&!H~DX2ua8vW&)1v;J@e3qRL9n89VRdqrZ_plkET`nJZ2aq#f82oW!vJ{XTY<45m$k z>np|U>v=iOw{+Uj9oToGQ6+*Q{;X45ammQwy#%)mdjGPIlX;=SqX;RlT_ltf zl8A{Y4)z3W2BH!`Tew>^6;Lz^o1+Ij8i*nAe)ulV+h4g75@JWx=9R}ht)XEZg@sbL zOJ-P)Q8~g(ZdM-4;4C@3?xpPM$WK((u9AU{+Fx9B7?p8JG+*=P@rTP~q|Ve!-QTg| zT#0J?lfQ3akA(hc&(1jhJhf8WaCjli^fRv9*DsvH-yFzToJE&6YKJ0kE#bo1`+2_K z7(R|-yo=$WnD?6K%R*WR4}6Yc680kJI69#)HHh&{rRm14J$+MrCl8$Us&zM!bwWRE zW628Ev4?u~u|Y5BU2H!_kgx8cG*rRfei}5RGUFoEf88vV2$JS}A~4;VM~2A{kV-H<<+UWn-% zatZJGE?qboyu|`N$E3;THh0D%wn{z!Vu6fpp~}d0&p#*JX}(lYom5ENb;wEgj`5FG z95o0)#EVLhf#Guq&)_E;PS@W?fOVo!9Ge?NMmOM*e{Su^rSL;t_=|4A2|l%%NNU5q zuO!j!BHuAvzRO}~P6y#`I5{C!D1g0%=)$=ojs>+v+d=GYJ0p4`8WW1uCcNl*8bVacQZEF*KDg@{oo9 zKhP#f2<7FwLSKZR1Fu2f0yH<*WDeO{FW=t z>3-=N2uMJE`gq!)bq0o$|3*qONNdkA2!;N z=-LcHe4*kIK7u+7d$m6E52%V4!u{{$5hYVO4B#V+?#heD2z6r}X$)UjAM{=Gx`n$@ zJVbCz65;FdG(iKsF~M3K1|%*yBS&8t+~dH!!$O(hfj0^Uh~2&YYSPgn7E}&NKYH`a zPj{kxH!!fo2>DnBA4A2GZ$dyczUWh)Y6BW99XDKC_}gT^FFhRIWM7&VBK{KXeOB9S z-o%2>a;N#pwFixo zDz~e$@$E|>#)2m|JLW^%%7eV}X{IU0m9>LI-Z6vw(M&pFu355Ve*sw{dKo>M`eydt zk>%?;6qG?tadouN>Cn_V4iB2o>OcUr^`dbP(dslW`qw#%RVT@iFHcRHCDOe!H*)sh zc@FQqD(0<~j4T|oX+=SUo~gg^IRdI91Dkxi^k9;wX7+?3i)#&Y%QBeV8<~K&2eEyG z!%I+)?=hEB%K&}^$CWxoN24<(0`)osnHj^gt^8GIvQy_R_u$b$XUE;{q;25!L+i9}*TXS8AHwfh*xtJNT%E2u1NlpU?m&>4rXZXv^5^V{2wKlGWc(cY7 zo~Y+HfE;J&F?FZqy+yImzi7KQtzt>Mx{ykUM7ei&Ae zoCIn^4PrXri1cR~>bC2{6XK;}Lg5k{Hf_>KCvq*NsnAIC6>jAm6%5x;uUx0orK$EU z+b&Z+Em2g1fvbv@a+GEK+4ZM|pXQ!FXQnUu8;v&W(&f@6n9)!QGnf4Xf^3vB_A}i; z7!t=LDPlu3SqPFVa&1sH{jNm8p;K4oswt?!PW2$LE)ZK^87h-t)A`bg92XBI(P(b@ z(+2GF=RaG420)Z$e>jftkG-RBcGn$~BXWzhKFu7T4h`E+Yl9Pv~nFmv* z25&-K+}938>}l+MVWqZ!pP=oRD43F5^lfE+9wXL?ZTQdGZ^bI~?PPF4o^mzYFKvsOLhmYN&`x}`Rx9uBPxUPb+(jg^6#9glnmRIB=6)Z z8hR!Oc-OaQr4y9#p9SrGu4L5jP8g zICK~SfU%{!1932zJl%myY7^Gssb$Xg_!&u-acL!*CO2xQU|&C+L354$MO1$?;dgzy z=%ymsOifo8vvi-gdpK&x#fP!?7$BE2Dt>NvhEO zs+q^-PG4Qp$V1@A!%xumhKXymQiYjrzJiO_fCBsqlmKe7jA-LcwHs{iWb-&CCAFOc zyYwp=`Fz}?RNEFFm)ahQL3&jd1ELCfTG{PrI=8V{CLw~jtxfq~;Wv;ulx;LQ)WoY^ zMr~=@fdCdxM={N=`AAco-7h)nD%r|pnN-)C=S#5zUPx2`1PViH8-ilU4 zL1xA`_epUrp48m0_Fk-f>XuC36*|j>YQA}>`oo!=j0CLwvznpEM2Y?76c|MwHe5a4 zU!Pcyvi`8IG4t9?kNKd`Y0khPh{b5;?CMA?n`~rRbAMBBDWkArI)PA$Go>PtZe}o# zM-}9_X>S+E@jV9(5NM`IAgU0jYxH1Ih*-I=MLW7b-d}R7QqofyRCZb3~6v%F0RrubcXG~6k1jxZu@ROF#>zy zQ=tOkSb^zR8MR) z?$d;FJ+`}_p3wyC(qPCfBikDfd)N>yG1rhHs-YqyL-*GN7{2-Plh@ui67@j#*k)<^ z0u`*SMHioT+mPdJ9GV|nJ@0gZ=q5a4{DP(7pEB&LbE~VNRJ3>$wHr0`Y$KW#(kaN_ zmHt`QhtXi}eC}jNh`7vn0?t%(T#P!64N98p7LB9y?TJyQ?d5VDE$D$VB655(TL(nV zwg)Q>MODZ`S#DU(`FBz6QRA5SqX>^RZNKpk6J-1jt^VCCg23+lbOn$;C8+nTREBV2 z(h{~1=17Ld0@#~u+WRFuy~3;5>c=e`qp40OqK#V4Q5tL1Qt!xOrB8V}9FgNZ3b z_HCHyy+Bz!t#m1Wyqr8RA)GZ{a^mH{(QbW9dKMmP$>fHH=ykOu!Ky9IU9_| zG@)y^v?u-H@n^lR9a%H{-)~3ytQ8&k0#($Ts3!SIV(kpqI3gO zcx)>@^xEhjO+g6?12IOXjg5aZ`SPa5)YvUALdK z_x=BZ9*VZQAMWKKIbN6BSDW>H11oB*R&v#;4WkIC_^OH9J}FV@gyPy6yD`SlT zNMu$WAP$#_4OAYiq|@6rK2N)R8>k91keu1sIhTimphiYeV8mrJazc?VcGv$}tJoFZ zQS{=gxS+xJn<*}XaRc`&bx8yj5TmN6XFJdTUPr+1vF>e`<=vF*95wLw!2RU_DxtZ7 zJ(XcDQKFyRUYA?`NR8X5YX#A^?UH7?^}9MBC49*wpW8)D$Y&cRYnwTot`3LKkX2y<_Kg*=Vax54{e*a+*jqL2$d+f^ zcX47BhMhk{xl!^0sM%3fm1xqYUH8)D;Mdj{tx0lKG4*3JgWIQNv-S_r0By z8CYa$I{6DxXtDp<0uaN71|)PxBK08w(*))Fpzg`0AjYbpT_ev@2GXke!CbThrLRC@_)6TZi6V>X!#y->J1 zYE4$vnpz};>>h&OcP*MMpwFTNkrlbbB+now(XNKXHI>b0l6l&r%*?P>jMpG&Yz_-g zN&)t@C`4qtYqwC|aj2MJClVr}oU`{hQAj%heNGiz$X(meq{jNENXuzh^4O$L9Xk1z5(nw=}xN@6kQ zgRQ^ecV5(!-Kp)pl@umos+Mf*DA-vQ8e<}LSEsTXmb1onb<>~fLvo23*V@ZE-4$cM zta5c1+Zq)8vgSo57Do$S-T7|kgkm=d&fskC&&t@c`kFD3bdlQxF(0az|Yvh-bM zESnzB!Z@d9r^a9+7HoLKy+378dZ~aOj*dG11U<9enc89pQ@e8XAdTdr&ss!trc8yD zPX_%xJX08Z;Sc>P)V%o%>Y+#&1^FzmF5X=tVrP{~lMo!8)A9n_`ii?ojG8gr;Kbg- zB@?EF|7Gm~Ya~^}!emBuG&r8Jx)!Ta%Y2sHdMEx#6ud`kS*pRuk_&)-S#h<%vi7(ij5kn$* z@lo~UBQ*n07As~;2U?He%ROds5OepX8lFOsjw5bPo;6zo^%{3m8k9FL)U>>$bzNIZ z*Rp=PF0b{yg>W$ltD5lRWv7iPi-ZoB^|=4~PDtza76J>+J{EI31FL;avFQuCx`g!P z47x`F;5}ETbLRX);h0D>ig?x%M$J*P^kO|r_md|!X230{;+v{00UfCZ4d+v^G$OAP zwF;xmfPMFNXIAugvmYUW5)Do46|=Ld^%yx-s~Y8Oyb{GZ<3sYSEuZAVN*UqQ%%TV^ zQ;T&qDMw<-E9*N_=}$k^8=L9sH1jRk{+VykYW39h)cc0yM17V+|D>pwa7o+R-8f&g zQdkF{sva;=1Hb zaTv{-L$Nn5-48f~1%*M?2D%aIVA{ZSL~bTIJU|y~SZAJvF1%H=$*6499JL#$F&pjI zul7bl-eT;jG(_{pA?CVx>Wj6xc z!%A6YVE|2jO1C??DqWI(xXN)~Ssg>94(t4&1 zMTfejT%cU#r#VVo{3@Kf!2BBPfKNlPc%T>Mi5H6F3+X>PKRFm|?*Z?pNst_>#95MJ z#BF-1opwVRVNRIuQl+7V@54Fro__U@eXM3-R9(UvGyH+3C$T^ziw9?-XL{ul&b2W`;oHFk76wwwYrNi{0d4Oq5)y zGe=gFnaaG%Wf-nXOTJpuh$hG{47N4b3OtQkoC{)r{RBO~+~%f>-}Hme%)WgjDUB`D zZopn6R=9*WHk`L0Zo#4k42}Y-pi$so>*WiW!FH|h=4l8Q@108(LUfk3``yRQS4`y4DLZ`MV%3;R+59YNsM1St0nj&-fkC6mc~KnxiraDDSY#JRhQMG&8hnmaFd9!$&=7{veHV6OXgj2$TBTd%9SUKaB zIpqp?GV9aOz@$^TH+07ijXy!lO*qA-jnu)(C`1J4P!?+^ZZV#R``l0ra4&i)cR0qa zS8yDiz0S(Aa!RB#zsy$np`SngD|Q&+LGxA0FDt&Ts9_WmD|-qLR5QUo#JsU31Bf#}TyIQn5o5WClTth5!OQ z6~{xX^JZG$&tLZx=I@0rW%G^PD$Gnf4m4d!;L@s-olC_me5x`>6`Kj~T6 znxRz}BC$nQ=zR`&`}U7(87(P_ZD6JK?a#&!jzn(YFh?3g^z;~ibOkqNCc_4jNh3F} zKJjRWOTK!{p@C;ht8F-kJp-?11KXdLp{AN0m25x#W82zVzZ0#0Y|}+T2<=c;R?8`A z9=AvdP;^guSx&JTe;ldH1!}_KjrJICg_qlW@(tepTy3$fnc6s;`{=x?I`2Whs47cXyXvrI1 zdhtTU#ekgivrcFmGxTZZ7&*7awR24j2y7}Zc6oAh3>|&qrV|X2{kD~*Ez9iXs(Nah zm8CgZYA@@Z5n6LCr&2%D=`x}6uQ?A7&V6khDHW9+C+IiHFITZ5+7R(0 z{0qzGPI_+pkoz1RWaFJZAK{AN>D8y{=WD5s0<$76y+6&b9nW_YDsO+!L_m@M7%(p4 z{mJPsAOtpWfd7}Ud!<84Wj(YRIgX&gM6*86skIq%Gb7cAg}=wWUr#PyVBc_SRnK$y zaKxP1#ox%4WDvAkL#7=oqh5!HeMYx?>27k4tT&IB%e?J!viPXoX4YKy2Axh?TGPPt z>}6tHYm=@f>hC|(mK-h4Z$1LKw$0HRG3~wCq0Mc}ukhQ&7P{{|r!y~z|8$a(0b*~x zxBTN%kN+=3qdN6($~vB&tz9>{7ftS`%S@H}u1{V>xj7RZ@5&HpvJ(*hho-BJr~Ci@ z*u)rS(>*;o-Cfhnbaz}=UpX}!rkm+$uIB1C&D9rIj_GE)e%EK;-{awL-1~mN&g-0K z9T9L_%Y1D*yG#9`<*7Ec>%BbyIuE7LpD$_~4uQNM%F*VsNKJdz45RT&3TSrzSHHfz zGL>Z7)a^JOV!8PBMs2R$@AiwV*O%P)?@9r%1og5qstNW;Y!R+`1|>jk^ovp5RMX{T zoaTPW+Z2f9WPNkd!bc)Gpk-Z-w9#6#Nulj~Gop%cS;3V56?F_0cW-VGz|!*9&FyXQ z^#8-D6|`>Ck--;b;tswFg|f9cp~b|_A!SpfQzR7RRTIBk^#Ba4(XSzPV$!-+^0xo` zT{tK%Y1#$=gYBEc?d|a@Q!YD>A0HVRD_Oz$8no7B&&*Ygtc^7wb@mjQyM!DUxVoH` z8&}h){ud4{1c)Er+{|Q0vKKM`p9P>Ag=Xx*SvOShzfcb);Tp zN|fsB%$O(m*ZM6tJ{ITfZm(g&72AB-!D8LKkU@I6!urBJ%~1-vJn#cKw%M8WK}++_ zzgjCp-MMIw;fwGx*L$B?TP%bA0bxL7OB~B`N0)({%OBFi= z4&K>QG`SBs>P}gdblavnMV-h#9)BQift@n(v5UKm2%F)7To5yMCmzP1yZzcRI2K3i ztL84JmV^5DM?!Yv0(y*Wy_?KVtJW_^hPaD&DEnh0E0S<1fHVggg@R$3Y$ksmz7kB+ zi!nUf=_MA`RR#XLOXI&Hp+O{(#kqtod){vOH?Y{MunQ} z_*n+I)6_DDt;*k5EJ5lBy?otZX$1;-k_};_{N=kf>$?w>lZ3978G;C%TXDXWOIbt= z%ziT**evmwQRsW{^iO`?W!N3t+Rxy+mJ|C^xat{WAJ|o5QDOT z^vQZ%`;!Gq9w%$PL-E6d5itp%(@BTJ*Ad?RJ#ow&o-{qxg@tv5@^&={UAz@QNZ;Y> z%!te+Bh;kk<4^p)maPZLfG4V#JZM7}Dc=9?r2hb6kEuL(Ms*c= zIigZZoQWf9PprW6K9bXfQ)DwLQJ-mc!^T&2`@dvmP+F`re|<{V-W(X$^bRZ=mZ!G& zCP~p|3$KrS@YwNjvk5IvM(zvay84djl{&93K1p7S{_7wE1BokinL~E;_uFMXy$m@- zUaaX1Z>NEHNxwfDz1?$`>S*65PFEeiN#{=EGwzoA(cIPi8`xZJEYsH9G7K%gyq^Ww_*V~H&EHrD|8IvcIZ^`%*;(uc zC*nSKozhqeh!STFGIki$#p^i%BJNb?P7;a^93BJhq;$s1hTp23@cYwc6&z?4x_x023AajzybSLv z`DsgtKch`~(~DzB9OcCp5CEK5G z6tjTwizUt40(W__uC)Jf0hCt|;XSN|zn=?x-@+tam+SFo2mny&Yw3Aia3icR`Kx?E z=X#i~&XG%h^j-B?biRV=Wq;U+(yd3dhNay@a`t8pfF67GqaN$YD#-|s`Ss5TkBg5c z<7K=%N;Yxs)SK{1Vd+_#wOK^nscOX-$A&hqaIv^&VKFMMoP#9aq{GRI>&1*6S$j26Z1{ zQTj`)jJdFyyh(VEI2`R?Gf}MhT$e~Bd(0pppd(Yjg$BW(A+Bi8V8mU&$;ADAsXJ~G zpFPe|L;pN>(-1de{I`3SG>mPRRBa23&;2kXsswB2FTxf6cQpcL@+EgeTE>J_zTbM< z0G$H>;o-ctL@E*w^dXX8jK3-J>gngDj&h zqxi#!`s8{-%4RMp5I3iSm8C92F6G!6hTpj<;(F^GJ{476SpJLKH zudteSzJIJ3dDJ-^Pl%nq^3%Z3t?^?TN2WF!)GnpmP_u86iNJS=0?;Xlt;^^6uk5 zE$*;%{my__{L%Y|n+g^iOHzxTI9Kk5_2>dAF-Z*`n+Me(#h0VFxAU+p2DfvY=4<#IyxDvvo2(&8b(iPG#Xg_XJmmmkwfnY37(NNGnI&q# z5Vr-ru!~`zwQz%;@vTA8MhavMgPWzpOyo1B;9c_nQ?N2i^YeDXo(C_7BkMiuLiaY} zZNyg2Irpw}bo!5s5wc25Go?IhA@>Lvx|7y!Q6D-mcicfm|2UuebQH|es(HU?hX+d6 z&*iV!t)~#84A7Wrl=;(t%dT(YnXWB0jOM2gjzqVB2!(XZB92=Snj8|#v4LY#FM=l` zcuI8=jPA#*Q^%HJ3yvyZLqWMPItzHVOF^;2Yw&Zti|4xA4+gHtUe@2<)!c+4&{RSmcgWC$Q((ZS0 zI?NF&tFe54=ownhwWk36F~Fr|B$jZws767HS|Kze z;t>VN6fRFNlTSJg<^!Yn&TXd}F9gf0eYKThqr>w#S%R-L|AJfoD_J=>c*IGrE~kGK zd=Hzp_6HfW>wf%tL0HmxXy7yMr773MW3({X>ljLPzbk9koJ>tOfJlUI8Y)R!j|iDe zGEl~tDsBV8#5yYcX_3}td=W1Iw8+(Iyp_vIqo9aqe+X7BwcG2uOg}}db6f2?$mZm>3d@l1OL{Aw?{@=YZGl!b%8Ak14FuH*Z&4H(!<%G za5Ekxl<6&pHJ8;+6F^D!&#q6vY3#!>qiLG@DHt(*SEGvq?u{nJf?#V@V}`JNm`zv=!_fhOU$zmNTbrk~bsN=>j)w$$lj-N&NrRgfg2+ zk1vnFPi8I?D0om>APPo@=%M=>0`!wlo3Kk*T4S6br7#_9HSE6&VNi7X9xPhay_rs+~QE-O2Y*9w3*Yf}; z$9uniNcY{?oN)cn^mpZXgs#|iv+F^|aQ~zdA^?$Ysq1ZWRI1LVff`%JSEVnSgq44r>ha>&lKYg3nUOQ>MN#WW0rVyE=sNB>2msj2 zkhYCX17dm&jBu`h{GBc6#!%)OtTCU=X3$Mm*M6*N#CeK@(vOGc?*%>559peD`7rXe z_yxm|(l>k-+U1P}H1{vO#dXTX0??;jVf~SWkaI>fJ*3QzT-nu2%cfI#-GriFVy>Hl z+}X5K9qp2=OKXdFO?fe|KGr7F^S7i}lsv7e6nWQ|XRPK~)B#%SGZ?zB0!SSb`AJz# zQP8S!?4G|2qrH&fF)o>Ea9W|{0n`D8hCjCGif!F2h0yOQbSC~lM@I|RZu7sti7f^K zfj`fAle9;DEq$6te{{&BC~rBPl41P~K5~_DEsX2pE$?k=L=2=u%4zI;ODZZP0F4wc zMM2sz-TF=IYJnkV6#= zVR#_+VGSCE|GTo)&7XRj*L|3&;TYxdb<|5ABIKMoIb>MPTbl||)!KiJecTxJ>B_je z1SlBt5D%eFdFXD3wk~{NFAO;Q;KoF@D&cQSSHr-eu5)pIK7Tx3<%1pZR=nybI;i(! z%e5o*WR)ca3qKd{l1%p1FUI}0BfrXnxoV4ej5V*ZPv0^1`W0Q;jVkeT=LPh?4NwjG zR2Z~>b36_&P5&Z7f|YC@?g*clFt^whW!v`Xz;KZvb*t6S&)lWysJ4I{Fx%tY{HE2X zrnmB&bmS?KZE&1V01fY0Az3IVT9gv$DLbN4*+oAmtrCd5+}pcED0vF$e&Bz*0P zfJ5pLU7`A*#`;ez#0M*VyM3zHA^%sHo~w@-sB8qdvCHfZ_N!DI zWNvkGpM!!YxT(K5zB!36v6k;pVBur0MBHa|$U$=(_u&o>S^hUW@2w)T8tD*{iM`^W z3Ahg;=WpNW$VW#v1P4g>nzuvm#CpTyr(_9%{>&9f9Xvdwt6urm0Tc-sO4jBANG8k@ zrlFS8SL|t}ZGtq%m|0#%!#NLZ&r+L*XMWo{!CU(_9i`%a(fr|h@`AL{NH{6(-^gE^ z&#_(z%z#Ix24@hDw&ho9;rWxo&ON5;17=%8*L&%(3Fe50!e`FCi=kzC;JZKCiUD=I z#SlAo-ZO$90Y_A}-3lVr3+wO&RHb7f0g46><1^8`y>D*+(cb%N!Y~w@NE7(a5MN&~+<-G%_=P^vTeY`eA+tMc2| zldS=40o)3a8|QEAk_JA}P3Vb~%W8aty^W|vhA2l&X7yXrFJ}$U8zFER>`dTAP2aPz zm&P?jYT0#0Y$_h|o+(zT>~!YYXIz7u}P}t^Zdzosr+@m)A3w;DB`wW9jJ=lk?TCmso_!RZ#adRcef&KiUUAtsD8MyNJ}ZUupo%vC6G}P zy<#(_V!F}gzagO0oF)*D?d#|lroR~;_P{WK=gp}rMv+x{L$$ zN;4F1LOaV+o^{sKosyowEwyMr-z4%o`#_sdY zo4XJBP=y`)4?@ehhAewWFx63(FzZ01VzGat5_K?!tC}5S0E4Sl?i-Tp&NX((7>_Z+Us7Ac*Bp-)Xa0) z%146iW)@;=k~ay{i4g)*J6P~8VMosQxp#WsF73aMc>n`l*Sft`&h&x(p@_c_LV;>G zTfX}xUa)X)QYqc`ffw;Uy+SF=>*6sg>(sl$96oL={xpf?Ix1xMu0H=X`Le$P3H3hA@TQEf0*o1Vq&n~#G(1G?R72r-@z0#gmV z9IOn`#6^l$lW+abt|bQZEw)=`8oe5Eje81ee7$^cwl zhq5)>js$uQ`F;CRkFp2G&+LfGYs-gWAP}5%qxjx$YyWfga8}BOR}q3N$Q%%2v}d>B z1AUJO6I`=S)_D!0+Z&ft7yf+;T0&k&L5yn>Mzn#FB(Rm(^RqEgCnNnTyQAf{?{t7f zGSBEv-dCo}O#^8ZVm#VWU2?TBWT%AKD6}K~!_Nx9DEt$+@SuX=6C~T;^=Dl_uUYvJFUj;XOf1Rzesb~2w zTo?!p)UPR182s6oqm5P|pV|{AE}Re_v~x%zP_}}& z;c$B&qtDufWGBBD+R~MfY(D&!!!nALH9ndhT%E%XT=jp*tv`;1jlzCdP86epdy6%_LIySZA)ofXv25Y#iwt3-Kqv|-)Rw2*#0)lA z46pwwEUdpEDMgv6s81$U{$`%vU**up5o)8GMcc=(}}e>^mA=UJuDL7-J}V{5XM=O;&FaY_ZFy z2AFP^>Ud5%&@_}7X{DgF@hq5+ao0-G`LT>P*7^C|Au-?>fh~OH+0g_;U3vjEwS|d`ShF!dw<~{Z9DmX3 zZxOR3+S(;{Q(eoqoK|bok9JU|Jc6cW>YCm>W0Zx7RXR>#Y|JhjgsNg-5>y6guN3|V zC#G_usMOS&u^mqpsV#7iiM(n4wQ`nqIdXOHb4r|)?-M_0RpRy>#Z0=?HP%C?qSpr^ zYTZJpXh*7r?rCPIe^IOw88weh23xrFTYt3iI82Os`FBBsAT7q) z6L0ObZP&8x^Bm*ea9kCbv~ZmsuS~y!1Z3lByUa`G8)-Xw+*)6tWOwhZmu08w0r#)o`{fp7F`!6|=FUB^79D+oNCPZkEHMhZZP z+r~-`K@npeFSU3sur(Cpyw+fxAzd8w`N6H}Dvk?qvpr2xoiFofumeoYfm+Uob7(nw zYEBBj5KNJf5!zjP)WerV4cPLk_gb5WE{#O2ilA8ydHV8qe+spdhpc!rm9Q#4NnxLr z%l@^xP-RzSZ+gO&I8a!>pqO!g!(v-N&}7Shk1II{P&9_MA_&S!H;813XLApaqeRTyQKHBZMB=iklmA(I&l;10 zbRRFmpr)HSK=z^8JZoSdERjibP}lrxw$jLoqL;FznEm9TT7!&)y#AwH9|g@yeEqyb z4PlGTSb6>9hDH5j5-lIsgp>630nB#_XK(UuH<_Vac==O}l#B9S?A9(_AN$i50Xl5t z`hl-yF$hm3LeBO{*YPheqK`@dO}?Jo$fQDX3FlFcay5#dvGTVZ(O+=WH|8mabZf$d zlB*+{`EzY(Nr%~^w%Ce{S}vAbNBcg?3iEjN;S4B#R&H7QYdf~*%kKn%G+6_F0*uZY z{?W|ub)JuxRilO0$y+X_X{k}oD!3byF>uc8%5KiRF%q=s%p5A+ORGJE*o_E+wV*XZ zxKSw%>fbDNzZKbu6da!NE!6sgqb?zHT~mzv+t+6o3c$O#i8En{ZI|`vEOf`{pF3=X z0J*QoP5<=|-Y!JwZJ)FCzR*8mnKmw$^mtPBuE6IEZGKO1eg8|;9F|C&VwCrrqGg=R0YxxsC+=6uT1-rDZvKnOY;{qV_XQ{?$_;|$dPZ4yT@ z_^zz4B8DIYsHr5st;(1HtAD0VYDoEXdFgsR*KLZd-bKWIFHEV40}&jQ2O<+$EfIpeH~&CbE8tb>QRYT-}ue}-w zf&tt{x8@0yx2E~(@pcanIn0%`Ch}`}KXEKXvFC8mnswOFexoy9Xh6>)K+^yWG-`m# zaoPE<{e9=VJ#My_UTu~^I&T+NK`Q|S zd7!@B05}5&!74WtJG}1lDM-j6VkDV%5EneYY4SmRO|CWZ`i+;D*U|sb;vt0MfxOHx zMD{9;7f>hS3KBPRja4k19_(t6q+@&IqIX#it60Csd}%mTJ*b@yO^2u+{iuH1k`pDa zIGMA9vqF<=?6y`$K{BRqmOD@Bvr}i~IBzKoF?>h6tTNftLpe7Ig*2oPq=&_td6hT}lyfUMsn0jZLn$_i6_d=qhf-E~CoQ05I zAdWmB_bgAtJoCYKznc+nU+Xoih1`M0rGvq}Fx z-23+y#^+h0zF%BOB68_mTGJBUN8x8@C^YK5eKRA5=ilbOkFsH;f79%PxSj8%?Gn zo)Bt(Q!i;?UGHFQC9hUiP0Kh2|NW*2xII~%xmYu($X{7~AW8uX{uzcLOdoFz;bH9( zd~g1b_CmryxK)QaI>gaJ9PL6 zgOwG9IPqY)ZNc$Lb*Fs*Tg8f~46T2S7nn$^;RXXy?;U$e(IIq!D&2C!)`TF9y%5|# zp=^wND~s{fi<=-A`!Mujy>c#a(^A}bGs9>cjoQ<3#Q{*aVQrJ9HIXJI;9h$si1T!b zjL%`1+V%VKWnb3Jy=hy|gz8-mQoejOVN5X6)^<`HXt~Ax-9-8+QYYB7!kVLmp~9vlc_do>EWFaDe=jZr7&t@yzFzy zJgnx>F{w*;PklKJErhjz(51d@Mv=6Iou>f8kJuLDdkH>(Wj4&gX8=}mS&*q7+KX(p zfC)+=f2c7j+g#PLq!KQI^53n&n@sS99vhvc+ zl{8gDD5cN9jnyTBn6-EkR1+|`{ph#`gSyWfwDK1-N53mGBF>LU*9=0rL^DxxJvOo7x(ii>uZc2CGv#9n~?WrI#z^FHXQ6d>z}n>dOz4zBOTp6vJ=2 zL@HOAl_@#a1w~?5P5#lDAIMQ$gcz~%BEg3_EaVXg;DPZ;te_Zsjmh5nxr=wx53cX$dmp{Xm<6*qSCdECJz-xM+L&=w@J<=; z-jXuA)h#4=mK})Zn%CXp=jWvR^_wzP2l1HMb*;i{B0%U4g-GV#lgHXi*``gj zNt&RqmxX~`OIFQbiuq^2^<*U?KNiEZaIkUaL!agGi4}U6?&{xuR!aaCz1C*ztoR!3I3A0> z(>kD>A)NPZ6$KRq^#T?3l^sjU zXJuwW_GT*=uFtkIiCc?`7z58AIId%RmI?!BuP4LV`TIvMInj+v{HoUBq}=g88k(+A z%d@B1jqup6Uk~&Vq((Kat;OruGXqP(dLF-u4uacJ(+8CozCC1gSAmSD(;(|QNf(oc zrQn)Ad&!376dmUvlpvJbKonw$E`;;qyt6aViq6a;X78i>PXp%OkHUd=9ytfd2DX{; z5H1007K6HXY#WLIpbq_-OAt`TG;_)${X=bvcJls*B$zFgeNFgwuHU%0UEUk4i^Fm; zG7{+A+k+oQ!YPmhl_ZbmqF$fow8h)7{@<%0{q7;3E^~9)N%9=53oHOuzy+Y;pEpB08zwG;JXCsEHg91B0-=g%&6oc5~p#5xt=FjN(F` z!lY`qfVf^egE(VsS5s96$5V|)S2<)wl*tn}L+$DX{u~RwDKnEmITN1xla@CV3rYa4xd56}ArL{BrTet(=Egd<0g;6c5 zPA%YO%8?egu19GC721ribU-`E+d>@M3qn*8)1&fF2def1!neA(06`lkr^*=^431bH z_O7rS6^43gY=Im|%coHgmyY%EM$qWIsr>e5l+EEUSy19Xi@+I+tl zYXrfQpFMLKfKk}zd~Du6pN1H%BJ{{IvKk(mNN#`sz^WHmv}N2aEeE&5;RXPLN*ban zS|s#BtJZP9LEKk&o>8c0b*P(? zjmK#4k_(tu5scb#B+0VzSAz7`5@}f`F;d{3KRhmcnt%4?Qm}dW^I0@F;SR6Ohp}Da zE(8i3+xVvOavfD+?nm(tX6b*!UkC$B*enT3UirWE9gxrI^v^}Q=(Mr1K7cL>KZ)w!BKIJ!8?7*@=xv7q3didS?yXF(-3tarU4*VQk)2y4nJ5v{TZes>Uz)VdpmR@oy-6~^E-M@cO zbr!pPJDis}@L!1=h49FJahONuNFk`@V#|^f;9@+^k}hWRXfkd&rCWaoo>Dd9hlpRa zQv3Ux4~day)Nk(5YNseKR14W9>>4eb7_ryWb9qBsi@6-%wI93YNaP?XljJ%H?7J%b z%H`ZCNo3;oFzq5$g}cp$y$_X#5S@-hzKMJP+QNYMP6=Wv&26NF+%Q-A@^^!-%11j*D5#WKC`_r^^>ec~=d&iB zibc3o3+sOConIj6W|(dc4kPW4_eTo|ZPU>&`t19Cn}dNk+WOY~Z$kPAQcP?)7t>^g z5)(hzT&yqSQwz)2CXhXL;`Q`rIiebFV-35KO~nYEhJ|0L7emIx)n|U54y}*%Nu$d1 zNB*0r;0gtS6CX~Qm0J|)=QC7(ebP7fl!vIyKHIavTdg(l&mP&FPHD{?-I%qXCd-T* zruObWK5%y9R$CirhRettxUzkBvqFh20-dnz9n&w`;-5Ow0nps$m9{R>AJ)Icz+{QK$T{-Yg zaQ?P)L}rM;PtL|vaxANGd878HWKO+z29xU_q%QjMhU=+8b?ZPoySPhzYeDO0aP-w8 zPoIN=t6J4P+JejVU(GZXnLL7`F1RaUi)Xp6{WbtMJVri74T0BaMI03?6kY<1KI^AJ0;(j(uJ^Y0xSM0a1M-oJ;toCx0@ zkdj{2|K9nQ-LdR+-DcuH_p4_X0DWjQxoX=FFzrf{lHT{*7$RtKSr0+PSKr4W<}$>R z(1DBOMVf7vP)sy?v4I!r+>R=tGFvZXNTlLB061X9BDGq!*{HL&!MCz(>vvY} z0rHyO>_@1@+8391vmrz@oO|PF3tCdTWd%WX*ZMv2?B)#q_kyc4xa+|&&NHW;q?_#O zzv1*J?c&2Mq`JkF?=Pd*N%hq^IG=YF>DZ}W^wp7lQAq_bQPpEZgsC*s0a7m{Pzc?*`kC)lr|gYal+DA#PpHS7;Hn89OtMekhR zQ_1R$j4H?~UoyozUl`6M}1)++0`BwEz~Mt)Jg zVV17{qrqgU6&xO0EnAkMMOO!&t)aKoV5TZEXSXzledZ7TLf`^DUPJ=PxUEG4(bTad zFzdE@@Dv}!YA*$SNs*xW0cK`h6yslRBn+=5zaX`~-)>d=1YQ<{YQ>1UR%?(r{*ib6 z&F~X-%r0BtbOU`BnkVC5|B-U~+4kxJnvLPRNi&vSS#b$*@2g`oeiLkU?m)Acy(QQ7 zib;N;RevSoJ9-s+ufoW{V2Fy?`__=jl)>xuMA`8+@I)*ORpuL+6tmW$)6LN)p>4#R z|K1khTjwK-=Fyzv5nz6+=-e8C(J{TKEsTANExP@cDPnL?!SkR9{dh~Tzh7{2ZFoHC z8_Sdxa=Apd&?^<)k41S&p!!8a$N7;jZTvjNe%0qu-!%;E0S&$N*5K%^GGjDtng}fI zijvo{UL~wQ&ll79-Z6^m?-nrQO;s3&^#y5mk`BInH7eN84{|8njOCw%lH}pK)w(%$ za*A;t?xK!ur}D*%lobtX2b%dAF8FmkeM9}8@5F`7(Tu+$R~_dqEgoeas2+LrHY9Pq zeH;&xWcAA!d(>W7^4I`2tesdbPfuG?P*8yV82v6Ch4E8J1Vl=Ot4y)Giho76#~@+} zR}hk6qtMHzeOhSF9an1+(2#w(uisYvzskrvTckxD=V~RHKI6oo~XKBB{#-?6F zs^Y(KJQ4OiPjeR!&pSaD>&>WjLdsfxP`;y*H&D}r{u-^$eHXh?P_|4!%=|+MOQ~K& z7-$Eq^lWTjmSbSt*RcpDEU~NFgJZRP#-liM^?mq|(u?|c=Cw-y64a1X`NLkTp&@T4 zhm5w-APXB!yN5)Z;&Yc(6u?~~Mni_f`2-qK`@XjfyeTw2B!Xeg9X9^mhD!d~8dD+3 zrEl*wv|d28A~LZ*iv0R`;{(N+JA5S2D({N2Va0mr(BJMWPP-%i)F1!fmwDpBo@fsI zPRs7f*lWi{b>GL<-$zQ3IP_B3R>?(mAolHd#dW@!9`zNBUY0 z2D{Z=Qmo%}d!>fdKP=r6F#lf7ef67roVpDm6T8G}x4Gxf9ksTTDEwUXatp~uMX$<< zZ9_^Vam#wWg-8`C2kI|t?)(!>H}5c~VLkKON(Y>|^kZr|H1O^kVJf*Pf3$o&UkNbR zM45}2_fDsz;02U6jCs!L$+FwMLlyvXzwlDL1fWi*2x`t6N4SIl0xFMjI{7CGm-^@w zSyq*Ac4A|z-iIMpBg@l%Kk=^D<81V`!+ls}$&OJY-#AP$t?}f{g$$~t%S+uLa#s5N@!tlC#jY4-*{8!Z4T%0{OQFYG=Km=F7v4MEc*UT{IPZl z-|{e_3A<4X-~APvD~{ulrs-ZLp1Jes4-OZk^xRA7slAmdzNY?7&g8tx)M=WGjAiabSKvHdfH95#^j~5=!M>`j zD7rl=;*WP{j%~bsD?6!YnJaK1)qdTVFMDSJ(jhIAd4WTOf1h(=P3=lht{>mz%and( zH-NE=ZAh(FeClrZxyh%fP@emd?pgOcbUyN8@Ik`V-`jjL=-9QX?gH{4#B7z_^SUd_ zCb^7jwp2Awu@INmk$vG^>W(jtIfW@!=hCSDZuyty)|81)vs4WD-kgl=aJKnL>h4)R z=(K5nUu6{YA}S{NOU&YloG@QRS?`aW=%>ZxkoC;zBITSwh>3rfB~?R7$7Mfn_s$b* zb^>L|wZajp3`Hb4?tw-!7>9323uZsCkv|jQj?hcb&Z<5v#T5QKi3NlC4LoF1-FNrZ zIoKNychJP;%kmJjhLXiFK#YBR9RbGI{aa#n3rC+#18&3vy7YzXmI?tuX;icPKV)52 zdYv&%GaB61B&c#ca#c{{ziMJyaz!!I*SHFx&n%S5{B4rm&q60+dd0N*uJe;RhL61RdCw*>G!S7^-5^8I9W_FY3dpTC3_xzovcuoIJw>bf*y{PyVhEqGyP6X*S@J3X2#<|qQK4v->J&({j zAW@2g+#1X$(4ptvp;bD7m7lAnmbD8#1Gi8)c!xrDT53h{_}5x;-G$ z{BO#s3gvi$*&KhIX7|3wum25_;ngwvxuKUi3W{r=M-4MW_sxx$n@P z&gi(cT__WF9zpxeIW#819=VK)XrsqdDINuhZ?z$JRIEydlx6D1=b*AFUUf&Z4d#w% zi4`o^oxDf-z(-prx>TWJo|Xoz4j6a5JfG*yEOJ_boKbz6Co=IKjBe1@%j`7iK?*_y zKa_uQ^au-8DmZe&hY!`#tARo`-YJ zo_*fE_gd>+>$>eE$jr^}sOgh1<%yJ!0iCLQiI8H<^TI z&?)PhhVa?*yoxZQmd?`N>v`|0if0Ippltewga<;QI{&(;mJ;=VP)~4Du~aw$1J*G5 z*xo7?7_)}8be@kj2m2ckx}rXQ?nhm>#_*sZhi||FKSpD3J>PyMxH9#F7rep4&ug5N zsh*XQ2|6v~&mzvO&myFP{ONvP?}0;>Cs{fFY=DSfFax#BtorVblq|xU%H6)mSxryz zeYXTwOIBQpxRuW8{8=Zy_$;OaY11Pnjp|UflQBjhr}t^d^W;>~&-5O(Q0cIAg^QDK#m>4s75Xs}!%3ApHF$RVlkSoGR>{`zuVb_Rk3~s!+s`@B?bz`NN^yBt%3te});oAN@U+7oyof zF!X9x5FX~fFAyJab@ciB`a!O=h;6_XnxA7tvAP;%)D{6u<;&E;@HB*>qON1VG=~fZ z%J|yLF_b(i8!60N+P^9-!tX#^LI?94@0DUTV`7~Oc}cd@DhLKM61Hu|CvBrlOU)6=%5 zfgdx8rwun43D_vteex9e3lMNEHnF)bB^JYbQPS>I-@9Ux#2M_#&e32|lRo?kc5 ztG4bjz8YJI;`tR`-b-syx3?XPB0pKp}g381I%gmnWk0O8PHd^u6w7$O*?u) zwTEfAX{5xgd@m6y{ z^?OR~=`FwRgb!z|emk^S$6#^Egp7^0;yFT_BsV=3H;gGY0yy70)wbEDS^r>FL}SYD zxTk1)=>{z9tF`LxY*NuH2x4cPXRh4<|Nl@={R%*bA>WUciJwClh%B{cy!Nu0e--!cL*%k^rXF+@ z^qke?EY!K*rSh*(v3hA|`SenY$H}St4E{-SWccm62!?ND(}+`~;U;$*FYDBTOoEV~ zvp2LR@v1vU>X*qW>Q*p`NYY+@qN(xJbcH5n_J)GdJ^0&dIA8ub{VQBa9Oo+2cI5VG zyz&wO!;h;1lv)rRo#X7c+F#?Wdc}LlT4_dQAzWz)7<_ntq_}#XaYI#CFS3P&gUK|J ztGHM{yhr<^K3$hcX5zbgmz-IxxKS(j3ZOzXtL5MO*|&5Rbht@aPAokd^KucO?-mtr zwcMfirt)s(_FD2i8)`4Bk`j`o4#`OB^1dI|Ssb9;_h>5V{^4PH(z)tY#n&W_)z$Dy z??D4Dx0CNXVZE>FoZI||$_t^c0c*zNX!)`{4L^jCHEG%FTOn$GS&Ej7y`>tUK56rX zdk!95JFr-fsrI8mMmS~{O*ZR$QCW-FijR~^IyykIjjr9I#) zWbX_|5!rRHQoc%^MiTxh61{w}GBrbu9p6>YU9#U;V+N2goZagNR?x$X|EpEY?=bMI zT}K;v&UD|dA@P!v8Bxt;N_?a5m2khuDfUZLgB^)(e-50j?nlD$0oEANvPX2O(O$p(3 zblxxt^EwqB4M=aFF=wCS%w2ap23TK%+syapr7)qC`Ei3d4n>&YeRw0wU+$W)EEEN# z`mTjRDk@(b?AJwImT3i;|HxVb#7B)kue1b+;se-_ibXXEemu&wPIYla`=I|GRf^hx zsqR;^Bz&hYlsr>U3FRP|&apgcG;=eRz76?)UBIeWl?}rtEeX0`4pI&6)hzNbXtS(UDx}sxp*85E5P1+zzh7S!R5NY<8D)Bd@BxQ z^wnheCqYU(6(l4i%TnZa&N=XKpR~=FhV9@NIC+8|_$OkcMCg1dPURme(6KKbw8u?| zmj1kG;YXkF`hz-|B0y>5&0rVk4Uv_;99VJ^ znTdf4IH;^-MW3G$B?`m7-Y>>wNA>t^{Ww8T*ajG%%twVT(={ZcV`Bw!ViA@;h9o}TL!?HS`1HOQ z9yTrLJzKBlN53qG^1nF2V3Wbf)dV=F4|>1sc?q~dS$3E;n>{li__RC2&P}TZHP=-x zDoVv)quU{@&?eq4+ad6-t%ce_UJnaE^SLJ9^3X9^wKoGVEmgYoRw`1O^o9KU{;=sa zgnIP06jfi1)D&tId%|7xrx;zE;#xiig<|+2@|3$e}y1O+Y_M32Zr_cubIayUfbO&s#xwYbjL;uu#9s9iL1il{J_EqvBn zkug#Fw(8*pf(_)#_0n_p-OV3&RfB3Dr+O`Q3ZiIdzB6~}X7qVw%a4T5! zZ9(v}x2wWcriLgSMN@+8Kh!fk(c3niG2PVV`wNoPHJBg9!3|`OfMH z!NL9$Oki5FzQXX(JB!%%ONqM3y*`D$9DTeORhM6Aayq*)94Fo=qENI0I8Bk8gY~POyDS;To^s47=7oxQ-P)ft^;K) z5W81F2_$MM`ivjVueUjDLjHg0(&NBynqvu&X^oI`XrH1d)Ec1vn~S#>E4$82qFw70 zF{~;sFSE6*Ls4T+()wb@tgsFY5$!M}tn{-m<}~GLERFeM=A5cHn0h8xkz&c>{yd+I z^>FV$e~2Jt^7bn0v?P~Cv%oIRQdpT3%wk+>M=y0_WW9YedOmXA#xd%9?sqYA1K_hH z`Uu#dItCw5s;i>|Lx416N_-k4|5c%2T4LkRzd&8H!A+3vla-kHoFWNk$%xdIN}VF0 zA&!%i$R&`Iiy5gEpTmt97tDI=23!xuV%xFt_-Jm7fIr^Itiq2QhkmJo70h5dX5CNg zY4*^pV>;M=E-M^1Ypiek?Jm`9615Gsfy%@7GPyX;j%l0UGLdb6zgE{vBi2n2J02Uc z1wdyAc!4opfUrJKY`=n9AGMT~S)z9~Ijf>qukb@d(@{{#Nl>$soAY~7E#&=27$RtT z(RU{>lP)>Bh_|`0>qpB?>-9$JSq7Q+qXS&T^Fo{Pa(gK78{pMRK{7@BGTR%i9hc|oDDTDQkYjc-=5-ViCSabX_8 zLR14no-gN+3^hHD3?4oKYs~?=`YS~1N|KG{qaDv0>AkX09n-K6}JgQ_6|oDHwgSB7Gsj> zyb@;Z2x)PLOpuZo0PH;1^HjtdV8F?b!j6v1Y7>)!WhNoq4X)5zU3cL3!^Yi``hhD2 zh@FC>)^qa@rfCBxJCcJ?My;M4@oqC}#3-H3QxAK2JpF^{{z4L**N{)UL5P&dBatnFdUWbsiuKT! zxKH|s1{iH;Isgs|$(16G?|+|U?(P^`Y82-dND)L(=NRdL%nFr-y_fn^V;5qcKDN@I z7th%#Ef$QV?$kmv)roi82vbn&VFk?PmXFs)a;?#$N{i2r&%Wca7FlD*ZBI1!Q+3Lg zw`n4SGJkzFci|3DSJ+Cmi6G@`$7Dv-i5JQimng(*3eEX?JMfS+om^ydeJdJ*DH8f> zS&XefK8kc}ocU{m0m35YCS=3P{Mi^Rr;qghd;%COUX!Ugjzszq3pam%is@xHG%VTx z`Awzs1k5y#u@_F}qjylj6|~<>;jY{SQ3F7C3l++9N8QZv`G@)N7=_V!(N0f@RI%sN z!Sruxa%CgM*@!Jo>9;;^!U2c5LN`7@f1osCCi(;;r@!x;*-N03emz|n;Qn7p*T>O7 z>hKkfD@*k^klag4e!HIHa(&yko>{jifAopo zr#k5b3A>jL2oPH_7Y71#-310Zez8ufVJ8+WP9;(!9e5xDsDR+asbh3vn_Ew|9+9MV zpEvZ?_4Qj2%{&3u;as0HS#vx5E>=mxE`+{@rHKFEs`a5sr;iC|UZMuMwkBCFl5eG% z$zXy+@171fkX&c`(=xm5s&?_*pUXB5x-ZNEjSx-iSMTfAO;aJa4iby5!=0@Kkp6n5 z`!PK!Z_M$r(%P&>ljQBAVBrOb1sn^cp_`I=8i`R3#2tW(X(rHX zV}la1WiUwQ=r1p zaeck|;G#bhq9r&z+y1?`3eD)&b|QFO9te!=1G0IgmxaZknYnom-$RCZF2j&52Z|O& zv;c_{5q&ic1|X*VM9a{+8NHh61w?h!)Ou6NeGMo5eyo@@R?iPyt1H5nRr6a{r$T$_ zJVwwLiVBOxvPi4`Qy6Beb&9PNyP!PQVS4&qanb+jd#7z4!p&0$|G6L;Qh7Np|AWPx z#nq7o_3cCJ3woom6$ar0GNVD~F65mlVbyZw)1y(|z$_3`@2-&oaxhyj>P@}kAw4QSWAr+U;k}JGwwOlvq+VY(2g;lS{k8v1F&A{O()rWQlz~lG4h0>EN@p zMZB`wjW7UTtQ%n#KQO9}0@-(Ip>`(gYN3u|p%zdvj8ieh`cx&ajs#@ypo|*v3U5}o zB}X1$0hrn~tI5|gBJ>QM`EX&=V;5QH1(wV>a(%t(bL6cuPg?+9l7VibOX+6mF>hz^+)QQ*oKDS?hF_>F&nHxB=q5YqW(1I-YZaQ{FLf6K7L5wN_Ul;l zT)xb{?5@E4xQkfZVv$>iWd%-BZkhp0*+&t5rsjHvEbaE6UCFBhC*@G-=S;5JB<8l6 zoHZz_;;Ufs2&gKX=ana@c_DUoX7NT3QY+7-QTzzP;Vx#$VsG_%tgOrA(IKzXvkudb z-+Opih|;+3_Vhl|iW6R30N%o=Qd6;<12RDkE)%VCNo0R;RpL>0@%PWtqJw@e1|TQ% zZ^!wbpc0*ZUH~FU-HP^=4>!427+=PO^@Tnioo=>q}AG0uuTG+wRv^s z5kYxsNPVCJq;jdlwvZi8U7{#OOK1k0l`-F{hmf_^xP9S2^!tGyKNqEN5#WUTb{ zydW|OO;sVf`n4%)7Y=hS6{rvjBXMH`!tQ!l2nR{ zf0_gOSIj{kBb%%?-Z7EKRB5#u%)GAWnq{nq2v{oCx@dgN+C+ik+8FuFlt_nycCP{y z!NC2!wdCdWc0)o?&hELgjKqrXpbk?xGXs;gog<&{>QC$m5?9y9)-S)evy-QT<8fJt zoHLMFJjVnD*ghNQ?uSZiKgKHNZ&rGGz|{!S@W{@Kwq&m$$%0RVbO}J$w!AooV;rm+ zz;FB}_mjSHUHj{y5zJdmoL3idk+DsyWifk4)iIC=PmtTxre4#_&r8pBF# zRPL=s>Ff^T-R|Wgzl@uhGj|jzelgXfF?d(^RewsGf0mE7DxH=#Pq@?g{r80TB5@6# zJ?s+;lX`)KlE)Dtz=qd3ZUID!5jgi=`g+~ZVZ1yb@D)k6vk@oIZfhiHVvaOhbq)6 z2VPN&I1*79wAdA%CbP^!tnK0at+en{dC}V%-w%hDRoQ%|wq=_1Y5Hr$Ej6rMZj%)eacQ#$-wq26G#`bn9}W~d*lX9_a? zhr5NtGZLS)nw&hVj(1$Irg6ppJ9wCwh3?chh6jc1es~vdqFngRviGy!(r1lGGBJ4V zfMnnYjPN4zSm=0_pf2pMuh?RkOMYi==+lo=UhErgy!g92ulr%m4x`mONF%9=F%<&6 z|No2C)koAWaxtt2qp=1BBZ~kcTw1aeO_2EO?5~3yzqIjHb`6jYg7b+S*VL!ab?kGh zXKIwn_Z1dRZ3&$A54iVBEy?u?CY&r2ew*Yy#CFtBPEZh0G$wq17+&dg2JiHj*gA+&o4(d&*z0V#g73?dX*dO zgnWdb{F`@L*4`;m=9MtkF2D}-TtkX@is>w4^OeeVJSU}Wg3ESnWFRd8O_&Pom!Cwb04DcTMjUpyT2BygsEPcHQkLt{Yd6k7!;oj07G(ObGrbYqmAoz*F#+)k|;HRFJYa@+FSjV_c}009zC zPENP=zo?Bin0=kAkISpBtqv)~fe2_`CnDqm4Fg7r>PCiwL0Yt#M#eFOlJUt9j@J{X z-j26o>2LjLiP5^2v^zZDLV zIa_>{(n_`QQZ=mPW2!G{RqG;7* z?JZ+;ruA92i!B+lE{7cCOGHc`3paT@_EWJGwrxJ9)i1OJOA?qAh1}H7FxMr6aEmPs zZI)Q|I&Ijo03r-Zey{^9Zz*C}4%ukxR6vla;|SUZQ-=b_;nm5XQEfruO|C1R<0wyA zlWY43#nBZa-ct~@Zru0C83~Bvk@Io^=ij%Efhx3d!%1n3BZ{oin9vX(UX|tj;X_^L@Vm#8gC6 zj3sBnPJ`JY?#sD;4dz0A_`wbu2?5~yH1>}k)3U$iud{V zI*jBXTjN^G6ly15*yL^^j$x>iO5^A$?V3VnU704Xt(PrCYiLp}NqH6;OK|x0=A74V z5wzIO!*Jx;!E6*Hzp+$jnhNhcJj0!o@GdIWHjby&orL6Z636SY%?Ye+DnsK2!e9@| z&#l6($j5-B=IxNn5QJ^@>4M7THGOaDzsGKHW!NRlrZFsu< z{<{e`pmeHLp1~H$YcGEGAjdkW2U2bLjWFsLRUHwBJ^Nx6GFPc@pZ# z)oRu$-OG=W``h~mGTAuQcdb0x*H7v3PpAEC_cCtVuL<>;BKrQ}_T!?@ZHunu2vkn9 zqkWm;m0|lbKGrEJ#tRIaNMsxs*mJPv^T78910=fT>l5@({Ap2z$VK~Wd;oZGoj|{@ zK~VN;5rC8RP24LNXUD+mSgsE`9$`(@nI*GYnY$29TMX4okHl?i^Z!c)9s?l`J^p#fcaSas>%*!~@}F z#e33{sqJ5S%IDn}ooz459zVb4+XEwo9W;M^NaIHB!u{bTeq2ku?5li zj{SAPalaHxq549%eprl9gUrN<*S{s9`&So=pcVNYpktTq#?yX`JUyx`)jw8@m}JXZ z(aq>8c~@d+s=BCmxG8w2k*QcXcUI}UMS`?x73m(X893s6M+g0?PhHgm>T^aKfBW+E zYj0Z*@d=MMhcV`a6TD}6Ld#kn&zwz0&_)XUw*Y~|Gvq>CQe3}EF|&NooAT!LvVY$z z?D}^lCFWcguQ^Wcqg`|DO8<_OjfcfxXpyPNM@(Bx4o_%rfCPB)Q+vs*tj;ydX$PJz z6T6@IJuh1#-8O3ZAKVaE@R|#j{-Ef7=kA#ymf|lQbZzH_X{;cmR55l$?3$7GDGT^7 zDTbej36a~6Z4YLU*O z-!N`+m-2YS?OEoRFsljN&Yt=v9pGspr^!m67_r$?40o$y$X0dE+g-z{VLwsbn?lI7 zL9tFO&qfjR(Oq0r`QWYERQZS`2tj_##+qq)-v&o@;#&{8B5m>*%oi!Ko7}Uhes`9w z>>xwJ`%w7bQ4+kIah=g$#3W=yAb;;&eec}|$?^$O4+*^%-Bp$v|6;y#nhbN_OZBRG z&fm&=i?}4IJ6;tWBgC#1pOWwdtRd$!fNIA>%T1vu{|keRuTzr%2c;@W8tuWTIl?m+ z?a4YHmm_kltM(0AP}|Wc@2U$rn%dJr)zfe+xp0eSV&lHci|Q4wGV!i|?NemMUE~LA zqC9y;Ab=8fspsct8Us=xrTu`ckm<2k)SA7DX+wZab(_r$)oO3bH#;>qWXQUYe6<)2oYpHCm1^Gq90T@1iTIU#N84bF})*1xmFn!MWPIn ztzkALJi?xzeZtL}^@8&|PRoPO3Ca?AT2f0Sxs^Lgxp&k}p~8CS%V3Ts4u&tR&o}+! za9A@={r&vgs|WSpT+(Bokn>r5)Xk@4b(RF|IJ&=J^K+dL*1xt@03%8>2OUjdX<|r| zNijedoosrjeaz8Znd=@DUS)X)otYYv2pe(BN4Q) z%+5}>S@C5hXmHcNMp|;QK;-<2S4eV5J_77fht=}efHSKTp~UTOGXIa<pi{_~;rp+dN_TGH}CWFj*r6&3qJH(ixCcoT#90^k;>({RgX)mzE$zb}9t>nE2VcI`RwZ<0LSGJAD@^I~! zf{t3``Z5q$ z$^2*cZk!*WLL|e8kV4FqE6p((u^_>PI`{LoLVCzyo;1%o2EzmD0afC;u6R+r`!beE z@l6s-QJIC6N*b23i6hw<>)IS>JL4J^Hq4a9$~Cx!K6V=>lOkg#1C{9O2|^O0?}o^`aGjs}-a)2!N4NkcYnQ8-3#eHTghsdf8? z7TtN)+1WW3Dd<&U92>X$hi+&@+bsF;OaZ_gBaSonXGl4%LXCgfI~y7Mf9uqkdVm4d zsxzR+lbvlXssz+ycKxYpCI#J8-1Cw(U(2Bc{d!|f_-YOLWI{Wx$pdg#gbb4X0hD{3 z6~eGMC}iy2uFbakykWh88>>_P_nRE2@UKh5chJLokCsup83}m*T0T~EEVQ>bxvxb4 z*an&<0J(Wfi1Cz}+61WuL|I^aeI~sfxrol0-Jx{em;kGBYh0h1*QdER+x#wEV7=d~ z^)*09k~yQCQ#%~SI#T&e2eJ-w#K|o4zN>f@?Z>XCy^IDIBMyCihF(()z#EYD=cR!EO?=U-N78$cHO%j_`(|F(*Usy7Lxr^)B5W+ z92}fhj2ANQE-rjgyl^-isrp)i8!G~_wZF$V9GY(*dtTwzyt(Js`&$Fa=$z$;7r<_5 zzDXKji;)_F-+rNH$}EH&as<}`EzSCvRU4zr`6by0I59>Ji`&~Nvo^^tzd`G%ZAZ&W zFB^HF6~AyTo#^TbuhV0qV_#KMHPE9lV$w%hDhV;2wwnF5{pF+D8euuwfQG_g`*mt% zfff@nrU51ka6@jc33Fp;@D^PfYt((*yoz;JbYHsSh!XVXCYeQ%O$OLua4x30V&-G1Kj1sNmEOJ=YZ1ks@?*-{zS?q9V)ggw z5D)k8d%G$n%VU@mNuvJ6_RmP1yek&aU9Pwv?nU7h}KQ+R=m?hU%S z#yWaF830k+-Flb92OlK}qy#3Nb_CsCH0Ic3{4lOCKLV?-pC&j5Pu4FCg@L2oWb%C# zQs)E?4}nSzb3;*#WrY#rZGS2?0xdrkMyrQKE!R@#GBN5Zbfx|RXgZL4fD=;jaN}Io zaBK-sVN99gGJ|&B#XD6C!6AWfKfZJx-`*T~UO8W8(nQAteiSL765t>h;Y>Rn@P8O6 z*ZJx4Pdox7YgZq6ll9H1wCbJK)grZw?GEwnT6s9$UsvaTsZ1 zKeawG$88R_^E#MjQpH)kwUL+wVQA3^?xfU)82t7;9J{j$tP4JMwFHyxCnn>n7+Yqb z;Rx&y?%2B6%#;)ZH8j5&S72e)>Kkm|@e%2Qvfa23ZOR(c`CugOzipbEA(g{dx7)u2 z$Nnr7M7O12k}?%p=HdOTu1e#5(K5^IWiG8`$0`PJmllnF;iD0@WHq45y1=1^mXye7p)<(nb-)5x#Af0*eEM=g(p${LAo za2d<*n~t3CjF~;OGWQ3(Ti_cNeoSezAR5c`isW5tj1^NK9bS(?kA1GxuPVmdDYy9?+vF$`eb{eOokrS+<#f}Xk`rDL@e*>{zvYCAxL8ae zxi#)AS({X**sA>KsmnC4ngTJY-@WW>EOmdA*2vumwfEm{tT^+ZDq)hv`oLKIa1<&l z$u)j|GEA&|bR8~O&61VCt6^(Kebnr<9r@Vz{UM}|>|fK2HyktvQ?K$Feutai(V>cx zJUu9&^hZms&VFFsO~s` z$X1p=xY*bT_H(3;MxjI8UI1i-YjEUvXFe?N3%2Fy2AaFI4jfpY{`Y zbmO(eqL`~!%-bpa9pgC<(Jd_a&kkERGsq~LbjGHcn(o@(a;;eMt(^q;HG95v2H)Yk zTNc?igOYnw5Y=5P5IdKyh#Pyig}J^mNe#+G2d$7INin%316#%FZ5)}x33hlSaCs@> z=xn7v_!xMH#%Y=`XUQ$xpJ)^F?K=IH` zG5^ulTpIQHqz4E7c)im2c0l>Po5}fR6=eS;KAlvox*78`9HOa{?t67z`O*DV$%U8PPzhR3vOqX{Pc zR-NB+B>jQQ=`ND6{AzY1NpqS%3D<`oHsabEB`R1&s;TwmLa+}ZUdK9$+N{=ozLFca z5(-!I-1#yCOhAN~sX~>4EUN)$EU)gVqZK+XVEf(ckqN3-+|9xn)m4jG|10T>7j9-C z_%XXm^vS2wf#lX6Bq~sCs7O`(Mbz)1>0dLru5Q*L?CI(FdFAfz4%k>VcTn_l9n(Sh z6t`W&f=dQz0#`_G{nq^;VIpwzv#V9-B)U$QqB@~PUcLH5DNd7-j>nyl z1~Kf7xuT$o+45H+m$d*idV5cFma}Ybl@K+}4*za{*R0Qti*DMOE=fCT%L?zhT2 zc!&O%PyV1CsE{6OMm>vXA&q(kA{M%mIV3e4QkEI2Y^r49wu`miO$?>;1}YNFKW+9L zLYTIjqK9S3YLcO~5T-6xLb{ImB}!{sOgDY#)^>&}==hZ^8rQYKv`xfSeVyW%6kTJz z)I&?^FBXiG43c9$r@?+uw!--ji0YilfItg#37^7!U>s8p6UPGayFtxw8kUHe8|W%w z!9$<$%z+IS4I`>cA@sI!VN0r_H6-?i$oi6!W~6a6^4`Vdq2oc|gfu!e|BTXcQn zmyGdd;@FC)53rB=hc#u0wl|JujNG^Ij}$XlChJX5?Er)l2f{p*T+ zL+?O=!PkYxE$mc@iLNA^sO|&xTZF&v|6BE%cq6mCazY0vr1G{d>y_K2e7c=&$`n-s zRTT5Ar)G^Z@7Rc)gA3_1xf`_X7;W|Un_PLn zMhNm+Lkw0LO|?uZvl=8T_F2uTaDB}+NIjlsqOlYZlR;nS0B`a)m;U_HS8Fo;Fd);r z#?qaUZnL=Qa>8aS7#G#>Ya%stiln9%2XkC{eDNO%%u&_QA z(qgsKQrB{f-(`m2Kg#Tgcz%?n^N`}!7K{+2?9UveFlMP!4%hdYb0zr4u?UBuPJF35 zPo3r6zBtf*;Vv+l__; zsgwpC5JzJ%gk+xd@2B+L4;1g>MN?l!^5urh&5*hp!aO|QWCnr8J48dVci~&4;Uhk>< zha8Fp2FU~py-ae}cSKRCqIyB z*U;N3b*I$PGpfkH(1V!#m4RyaF(lx%l6G+N3k2H#@3NsnjgPHed0H^fa}{JViuzTl zA>S?+oW)8j2{<&e3O@Y5=S46?l7TJ)wg@||dBe6N-jCvGQksTwmr&S_HT2j*9vf`9G(l|5&R9A~aGO4E|Sjc?*3} zggHb|(dTo~#@yW8M6dziFp;S?e(isE9#sKk5it+-nLW+Vrwe*xdvhKi7k|>7ELxSn zeQEiBb?;y8)B1wcmANm60~KPP-3NDRS1~|Gmw~J8=FavZD*wB2t0Gb-=SOdm8@RZ; zAlBk(I{>1aPnk~!%(e{-SfBmBdy^zZWc zfGYH58|4>K?6*2Km%zIJ2-^Sm$9a5I6=rXQ;{mXu+W+7C%_Gc&s8=ikP?nh2 zH}w90Q^@}Yy@dV^LmXd}X-#^=f(wb~3+KWtO1+o(|EZ<_0qaTd>En`h<*lP0$BrLj z5$W{bot+^S6&0>wB_Yx{X|1D{1I~bcNi2{{~gzPM+_N~wy%9Fhn8=TDFpWqzR&lojqWR+1A&*{pSHh45fbA_ zaTv_hQFlxJaBDSrFi#%OipR`^kegu3D);T4l1HN#3Ob;lkX|RkEzOV=FtY=8VmNj?#NoA z;;u#A_fpDq&5KkTi~v2Fpy1`FE2?W#$N!3{ z%DShi!b!r}L*~18>-J@@h^(TZT~T|DiIq51RM_`<`1-E)3?9gXtO%KlxLm(J@C6)7 zGZXEIAq-jS?i#Azl#>g3biZ_TahY+%{wKLthj?PE85d1@R=&cva&nq-oXf^V-K39O zW=h#ocj-l>dL7JB*rB{n1>c`0S z|9Js)=Ix974w`Vp)^&JaS1xuH@Ezxp-#@k71vB&o{j4L~F<{T%(asIhf|jOeo#YP)Bk{kwM`8SYMiw&qRE!=rOU>p73I)_;EbYwpu`0 zUvqQw1D*sDw-yO76ysW0CVFDAqSiZXAY**^eM-~E8L$}k-2 zK@ovWG77TWx|T^eq!Gp0A=z^4e=N21O23Bu_J3@>2FIPdC%N46Ms6~67*MSGHM5n- zBMwAQN4wui&8qb5lOZt=GRI!em)CKE%0=jakfbXn8-OxSK1s1Vr(f_#Y=5(kCaqNq zVZ(coK+fZ}Hj25vRqFO93p+bEofVPx;N@4EZ|09BWuY9Ka z3#`Hv8pvxCh?S%i+mqwlI3dHGn<%;we1iHLSLMjAN{zYa_l=(hp=JB7JhmR5?;l5v zi@*%{?;ql&&LneB&h<9$T<-3aJf2yP9gg)+LD!2hoU6rlSJ&6AG^Is26FsNV)(3QR zq)+^Qdg@`m9;vK}Z1~M5*B6Fwy123Z`HWW9eUAq~zp2dk4lEBl;hjT*H@I}G#3Uqj z2+NB!914hJ+o~D<=#uwYh;N4A>{opvgf-kxuvTg5s z=|t93e^poBpo=~P>bJU$&ITe-_xrzu9cw(SV#zOk{SXP{i2nQF!qF(U_C2?clz}d+ zx|1r)5CuO4lh$TL!cdb=|_a6$lh4PM}bvK#`!q3x(nZTHKxD?oNRgEp9=I7mB;PySqbh z*I-|G-}~I>{{8;XOwQSJa%P{k*4}G7vE>qOnYy)DVxSP|{vPU>AV0RVzgiQ8DZ|<}Md9|0Y zyGAI_$KxJhO8l5f7La%8nlsPVpF?>{aR!R?#O+ucoe!*v3aq@z_i`Wc-=r)#V;d1z zw@-WIzCqyIqbpWK_9?dLZRN8LT2&yv6WOi-b>Pj8DEZ0ob8#7;G~GucQ@?>t+L6wW zkIi9~d`B z=gqQf1a1oJ^A70Zy4m-6W#2CtcWWS?dE^%P3K9J902Oh)gm+4<9xqpNJ>un}J}Uaf zLmq8dI=F)+O81ug4d$->S&r>Ne#K`v%oiX%1)_)L(dtCOA+mfJtoa;mU zuy?*ah`YPExOhQX4Kp)u_bp$P<;sj$io>Vd=edLW718>_n_s=;KV;67nc28_REY}t zlsC6*&*yBBJeNQOoY6gz0)iE;B2NQiA4cA`l$MAp&ovVSnFbQSiF76-ibu-8Z|A*h zVVpg`6$dcawwqTUc6e<>#qS8%?l$@Fyq_O$mi2?{_5`wMJXtftcLq}`(Z7!f>v)BE zQW!k6JL2wBa&0v&+uR#19TNuMsxpbbBPFA9Ac`M9wSAniQ3uB(%4|``P zmWyt#ZB}?`NE0cE?s{2W>OCy{d^g%Ci7vW{u%lILQ!fd-Z+*VMjU2{=?lxNsE5>{^IIWO_2)r<=Pigga(&vCZwn9 z3S_GS1fegC9zDUdTBp4`rGwEaP4(4OkS`{dJ(Gg(8o4f zNc@>3G_aldHM(Cq`0QHC7JA^ffPZq|IqC1xD;nAo1m0MDq~)8k`KlG1E^x3qL5iXHg)DN8y2uaQFhQ%qAkzt0XRqZ}6OVfp>L z!-fn(sNlri+FJ2i+MX9)l-Ul*<)TNP0ru&u92DCd!Mg*|X>3e@m#yXTXh5o!)huK6 z#QVh4W(Z zQNfD2;{O?8J>y}E;ln80sWxp;(d4f~bqn$g#T4mz|5s5dsgMH+(kaA&TcCi;ko}CG z5CxB+S{)w(G88>bAo_>(&lnzs6#GzZ4F+t=wHQQmQ51c#Apm48fwW9{1YR$WZYJ}z zCjiG_DTkBZ(;P$AFn9;}u4Vx-54=ps6K#IsV*Jh|q9_>6brfYLe>cPQ9{xP;Wjqza z(~&uUkULU1Y1iopRQQJfe#Et96OGePqks z$Y)*)YF57%Aa$(nhT*BAhy?kT4sHhe#+(iDPH}DO9{~~|1eC1^rfk*r0~r#(f1S!V z=J&X3iym8Hq|~_l2W9Rb>%T}XI$thaGE~U8neaX7jWwVSp==aHEh(O36hh_4MeB@) zP;NftyuxKpMp)H#Nd1;`w?>WfqKdZ!DnP!YnjD)i-til7KF}8I^|mFV+o;4x6uk}; zjI$CL$=BsS_A0VHRR^Wu26l<*j%r8hHYVM2Hp#0Ey`6qOE5fbY8^Y{d3+L>nxqf%z zR6K>gkB&?P{6 zw~V}otLy2tL4E&Zj?K8#eI zXvv537|#*V!*xZr_HV}Ve&MSr6VQt{Ng#)K47CbTARXklc8G2)cNeWju`Qp9(8pwS9t1?U=Mdf0G4Hy`OkDjEc>7xr z??lofYkADa!^1*V1 zJyOUyhDTrD%|}>GWTs(g1C3Uqo5tEmG@kd>S{-oSY_m(ac4w=SrTFrK5>KKncEGJF<1@jm{)Se#AQovDrh5yia5X2dqPV#@se%D<9t;v?uNR@EtH<@`uSE1*H5G# z(k<6W4GMXEsJnP69xtuf1SX`DJ1F6pNdp>R6x|qdva}W&wUS=yl%YYLU}h z9Ve=^tRESbvfuI!0|CNzZSum&zQ*a13W1To{4uyPX|Ml?^1Y{6LixkLv5BqUsbdGg zdF0hu6M~61OVRD2=T0GxPDlHK^>OoL`+z=j$kDF*$mnguO*~irfzU1~Gi;pG&89@k zKpeuOK3XhOy}JRmr3cWq+tt7~wY(`{TgpQWa39ZI+zOj`pt_U|z6HFxLS$dCjP$zm z*dN`Z>;;|6_!I!AFiXPTF-{><*fhEP~+89D5vWT)r~XxL%MR(~bF?<5ZUSUK(7h=H}r}GIsLMw^7gL zzSm~HI|-{}Bo^jrO4Noyuk^>y@Th*g->HbC5Z>7B1Y433K{zR6Z0%Y)^Zie_rJ!OA z>$@~5Ua-jw7f3lY5)ZXgBbz#Cg2dTR@aFnz;YgrX=dRNpS;6T;=pBMnm6j*?jgQ#Y0h ztira|6bc9dUM>w*Wgda*GC!!C)I;SK78e$^fe6+JmQIxEZs(W`{#m#W&xgzCReImz z5guX{UcUxsU3^T$4WQ0MlZbnJjsGeVydQ3|$Lp;Ps_e1fih z-adqKFx&z(Fx4OvC~4llUVL(v9}CB?e}lKyg&?D(X##>2majyP;iQNL>kEW&ye@2@<;c6ZEn$$krED`ILQTd)zlGRV|Ia zrZQr8bW1nGrb}@dBoQ&0w_|Iw;;gM>Q$Nc8L)&UxeYD|Au=Uy`-#)Ay#WgWXTY(!o zb_wS;56H4LS9u*80bhZWygijd;q{{6A&F7&L2*m&*bUs~yBvxstZ5U@WQJi;M?!=9 zlLN1y&&&=qV3za3>GUGAV`GSqv|-o@dUdwheSbkPIxUJ$Xp&VF_}}ef>aMnUKNEyU z1ju36=yxg-3gJ18eIHIc9jRZ!SjdoH1YGCOQJ+QrKh{fXNroC zE`0HU05D(A-A_fYw?40sNZhShiB&KBuYk(-l4lC6X#L zT?wjz>Kp3_yOcxZl?q^u9E$lqrumhba<)(i4i8^?puRw3FD@GaQajsS#nHz+WdJ9GMQ;8d zQlR4V0LYUXpi0a~rPu=WxXm)C8=@|L&(4P=aLHXH*_jpBhmcOzh&~wpW9ijrl%fEN z;8`kPFoHs;ARmcRMTc8lxjrar0M(zn-|Tc8Yval%+$7idwIoM#>zN0l8Xe9B6{7tkE(pKZ&9)>@8ECaB&+g;!XN#ZaIhxB5qoV9H4^AsyYNx^*Vp_(g9FQiA_9heI z)1kc!k4RQNDNi3K{pwnnxmQf8BBJ#x-|4_DrT=zHtv3n6U(jMdyMJC0S;>h@*{lty zHWtJtFk0X$U?Y=Vc8EB;j#Lo{@d$_1U@H*xe91=Lj66|INQ>ZAIdA6g$@yK;0;|Ss zrqN?!7=Xq8w#x9fVu89ne(z3}BQKSZq!SBy?)P->cf6E^z6DawxCRi1Fr4CPzf122 zZsJAHweUY)PTu!0s1uoEEL5!*pk~y34wpt4rCIf!T|NW#l&40phEV#buhvAZ#fuZS zO`s!Y9OXFG$%DDJWqejDm2e-mTomtR75L1Hv?yH8b1)s*Xa7}FVPEmsX$_G5)KwLu zS&NHP?y;5RW%nss6f$AQg?w{>+j=-~8c8=G5~g{Y>^a-t0@vgPPQUhk3P?wDZsv z)SrUPZD2}~WIXd`v`3w zT^NAo@E;t4-v7%L8MVC8M4AiG@Fm=J>`T)mO2Etf?j`tzW9s$0zEUYAuHMt%lxJUO z1jAY5gE;eIa$}gK!oD8cz@6ZxUNdyy<*c)t&h}W%PpX5mzLB*FYtE47@JJUG3&f55 z6v$4B)YJv1bKUmDp$TmNvP$MIVzzDn15VW%Pqrv}g^Qf(ZB^{l4y7>L_|==7=r$;u z%>DtrLE0T|e0fwU&s!W+&okQ3`|!Exv|VnR-#yK!ir89_Y-C~V-U{8;noqw@ARN-N z#%)#{3EStJa%GIocyx+V@=4fEKD0odrKkEy{t#2Me2V6{I#P@--W-u-4cp$&dDp>| ztKU70cZcliRysfeE1HDNdW%T(m1&xj8Ije}k8X40NLE@^gGt`Tr#6jP7Tho>4v?PY zM0s8XBSkn^$!n`|+fu9YEb^A3|83>Du-?`{yHWXShRg2WP6J*GQghq+qR_H!rgIbH zaZZ``7dEe3e}6ocW8#{8&!VV)uIzHTu8Ba$hxes#kY~T--;^*;OATXv*4ia_@6bmF zL%KtE7TP5=%6w4T3SM?!zF`jJ{W=)_+kRIed_J^oNH_#?m}R?>13Kyp7UP8)j0Sd z$H~)%X-qjYS5(#?M-nlTTUbMjBQ!&OCL;$YZUQsEa;D@yk@jH8eMU5^6{XgB-1VUN zKHDXA_J40lKstaRc`MA>?!#_=JZ?jxTtXc82jH&x|8zmi=eM zW6Js_gzU5t8hhXQzMsMt{$byGyqDGH`)(^LaI@sNd6gG<)Z*P{RYqvFSPPHU{}~$^ zD&O}~@O2IMiCAAo0{7qIYmbGLtvclLNrSHDp>-#{L!3Wc(;*0+ig#W$Bi7^^bTYo`N;W)z2 z(1HS$HS^hdcSky}?|ZQiOk8{-3dN=ldy6!->Q0vqJgE1v6Q1w9AK!(I`Md zSD>RMZcEjO2VR$DYR!Ld$WQlpPes33=I!7J_ZW|4X!q9hUZ}D-YmeT+_L66u`4Cu& zIlaL1-;iZ0w-FzEx!(VlqGAF*tf2?erz5l;)hrHp-z=*MpZ!$qV*c9f{BOC^=2PB( zXKICj1xGtXenz>a+4qooH*4W0jtE{1@j)90HcfG!t8Ce|#5<~o|94Hk%%v?mN|+3I zHfIn*2@5@&Gm7dC({x#|$mfGju3i629b7M}jDp8X_`HVrzhl#1R@Rq?I5(O#+uTij zwEa&|5|F$J)gkjvdf6W1!Z<5AlAg=|Su(7sJRh;(QK2=K2~+;hYk|j8In#MY{(IHz z|MP6x9MOYf-j8`zjMe`QA3dt3+4BG22U=Yco(fFw*nZ=s>tOCQVs|rLjBnS2*fTUu zTTIR@*DzeY=1Xkfc~E?e*XQ+loVGMger|2$obqpOJX-FY~cH(0DuxzTjDlbZOi zcBSDYHfTU)R3&83x#JYG^C6tZ?M}0=n4E~*_ccIcKgarat8sr!pJj(-f@SQDNE}d;WXTKQh+YBa;mr;+i|SG$NvtZ!J(N#$z|ab$@O8L*wE3ThH(}xL#jxp6*ZB zYFB+c=UrF52cOs$TdiuBT@3Ub5A_36ugRB&0z1!52KBy*ytmr&4X&!ZuDFD|oeJt^ zMk*zV*euoQr>ngzMw=`YH^1?HeGiaDHnuoNUGfq`_BQX2GL2Dgl7V%JB*_ReqDNZ} z`>?-R?*)N`PPNu2TJY^qT?lJT9`~fS_0igWH!I|=Ikbu&snkvI;k0v*%JUHa+HVCk zMxjZU|2OuW(lTa3wG2RE+F{utb8FQB_j=l7ma`*}yS?ocx>+Jv71nKXv-f&Vs!XVW z9OQXBWJn8B>|=nv1=P5eigOB-ZjBb$@kMk7drY~0C@y9Uaiil5+~D1i|FEnRiEYHo zjVQ(iPC#0YvZAU&kS_m`^0BFxWiG+DfbG=P2U9-F=;A#4Ij6Jb6CYRiP4ZRs>m|q0 zNZcDR=xV5Bx2)K7Ldv&+(wVOV#&)LvY68t}rc(+V08u2*ig37r!_`iFbP9Vpu?p{B zz+|I0$6CEFSh8N`Z)=CNg?#UU79?Wd365E!M*Pl<4E{G45gr1r$4lpMGAVADXR-L? zBAhK?9Np?jm&*d7VpP%x&fBx83VDMaGQ6)bsYp_Z{D4oVA>9o0io$nWnTb5ot2=)` zE*QEbIA-uWnrbxVEO?&DH!!CBwokdw{ z`Fj#usMyZmC+P1D+4vISA{t)PK+^`e8INAZ=GvFGNnIU_?h3GF$#Yv@%bO3{bG6_; z{z_h2DQEL6`N;OGvTexIbpJPTyBVo~%AdcaW1p<0H^TY}((iUrd#&~s=FY-4s~^0# zeoLY0#9ogmPjL3Gr8yPSvJNKBe2X=JtVu5yXVicfvrt?3(G}-6v6$V;h;sq7J_kTHl;yV{Hyl;hXSA&S?6kmj4W`3r?7}$iS_QpMzTd|4>m7~IqUiXWu z9|jXXFrxzK z1@C66d(^<9nQ_G8u5Qh`%J9o$LKRho6#3SJFjQH7YQa|E_jkHDuGm!x*`2ndONsWF z>iRxjDFKGZ7J_=iDTFvN?3Zow=Tdrt&I@)X-24|ZL%{%3T$jc3 zv1{uqC^_j>RPzVWsrrXPu{2NF{bLSYI^Rio-cNLld~WcGQ*}&)-yZfaxot+Q&mFxo z5cUs3{!#svmO3*|s(GVeq?3p(E$-&vfuQ$#=_Mh!Ak18Kf@B+YkvgM8;3lTk4l!vB znVvhv$@o=THajFkHKf)xZP<#Lqq~;A)8Ev7Y-Sdr1D)*Z@6#IxHOwGHveOhn)+_*P zj!WS4ZExp_;M48q_5%>2_o4nW#t)hIev*oc0Iv>7Z#r*nF^E>l~lQ?Mi|HfXw=bV`YE+X}? zWZf0PBkQmouZC7~YC;}?oIdc6FB4PCT8JNEvia5mla*f$$U3qbBCR zm9k9aH}5e9;4k0r;Prb&|AZNAm8$+jdhVwu!s2H*eO>+AnM?f_4Nc8%Qd>IQxPXQ* z$;hHxZlMJy|NS3lh2Sr5l$QVe2;0|`#r)B39KPiQ-%l^U@OT;}EKV^{C%3JS#lOND zYWp-4gj@M(m+F|B9Xs%P-nGLVzL{%cHF2Tt;&jR>9#Mi$)1{M@%@QeYUpC<+?-ZSd zNHQT)r(|~fk8$Hl4Wna|Qk7gV3m&WzL1w+nNviK3rdRA9an{rRO<6c@94KI$^>OTq zQe@5xE%T1T>JX$B*x&#`a3w8lsm*Z&?|gh zZg(ChNyn<0&g-fxaXDP@`=i^hSUXHymb1@)jJ@%&M?AhmtU=V_@~JPmthzUr-AsK- z4+o`NoBwtv@X3zE0C!3*qQ$|-EZyDo&}dT8dMIJIfh9KO>LqLEsN4uzeSW$LA>ati zJ_vZWn9Pq)wXNq*`v*>t;U{=p&?U&!QiVz!)u#}uH<&sOE%2j7W;mLJQBdGAEKE)k z)S#lyu`8BX4*GXK?iZCtlCFm5n+AdW*+TLlx%&4{ zriJ96QQuR=sk;gl1A*g}MFNDBp+8IUvtuO>wepDrsmwVqWcKnvtEn?eeQb zN>j%}44vabmu_TbQ!NVIkl}f|x%0NIpyBIzlp1kzf8<_cPi(75`m5FXtd>r0zEgYa zp9usz5^J<`&xd9ri?iw8N@y*hG!y_#ETc+;O@97Lt#hJQwC8YPYk&q*LeBiOPi|38 zt<8{o!UY|isUy6VXrxUv6gz^%9tn&Pxv3vTU9FOma*jFC7Z`Iik`-UjpsfhXdaE^F zmCLG>yL$uopAE(M{_SU^bW*Zi>JfXqvJ%o&N?{>IiFg@?n$XbPX2G5Hclh~5-;S_g zfe_)i>0o4|odfL&4i%rgs!6l@$JOht5Tl@yrXJVHtSH!5#8j@A%YL4%Sf+rHFTs; zH&%Z6F-#9vRIic{q{(L%e@t2LAZVVm9??AIw%jRp8&k;4Eith}93yTy{eioDVB8;6 zC~_}2mNc&B%~eGa#Q?nPllm1O=`>|m&iR+kEr!WUMuP<78a4tvA0cb!=xBMJr^4Q9 zc1azY%c1DCpjd;QU6NeODJx3|jz?WrG0@I8w2YM-bbKG?p?sk$zz!VzdPxRwk2S;>>_ zN>Zf3;I@#gw76k(!q=vhE!8cTo~`1u>^$4;);3W}guv8G$$M}HNYV6OqZ2YX$+uD! zJR%?D%~}tQ7+0llu`qgLqnakz<-6HWXU?fXfTv-Apj8%RyO2Y%#*jW=Swr_BH1?EM z4KyZ`pj&WG@SoE&*N=nBI)K=*?Bo>k#6m(&nw222s7K46Zf6E~GkiE7Lf{90%xw41FE~KTk-9p# z!@6wRk3&QG!*O}mfFLd0$%J?L{o{YK5W4tKbdczbx)eoR16%)L z7;P+}vhIW9NNF|j>Y|a`u|ll5gN|J2B2F=vtrQs(`@&cUuuZqD<+uVB^YAkYnQVh6 zk~iE7Ivj42#*GG^&<}y(HBy2$4|V!~F}|<=`EfQD@(vf%Q9>ooPdThl8|eFRmWBO_ zm?hks9blI5w~b_>FWK6v;mX z_VfS)zH@9iK@by8XxJ1@JnaS#2jJ$&@rtvgUcnHhnumS^JB;uqk95(R)`6c%zuo<~ zfg2kn%~JlArUEF?Q+C>++|#c1PF+WVfx|hG%S_A(3xw3)VktSbS2^$l}JIY*Q3>NjT8J8urBtqYv1g#+_2sJrx$;92T z9auG;jdeSEwVfG@6PZ4*;L<5R27$gy==>p@exCG~mMqf=o5BFj4aZ%jxRV#zaoJmy zYGGMn6K1Mic8l@ukw6CK6X!KQnB_Ujpcl#sfYhuh(nM#_JG0Wo|J+jucb3RKYV)A3 zA_6MC*>>&Bk=d@&2qHaj*|$2y|9(M22b4OP!u(@B)($dg0|?UE%~F?`9Y6QvQ8DCd z>$vad;vl<<=ekX*bI<7F|H8EkxBdw;ZnW=E>oDT7%c$#I&C-n|$BAHY{|muPbD=2Z zyo&8z$aWW2%ao%=h_pmkC$}VP7-`@{ab!nOfyMo4@%Z4Q`7N;_htRLC1Xq|uHy%j^ zqu#M*xB&_AitkXaeV2@4j(nfZWYM?=R&pV{Xx!Z7dR#KeVhNvh72)m|zvHd;-3bn~ zj3Ho`POH=;EkMsZeCof}{-PS2dwWj9dQTrdb7{qYB+B%ek8_Maq!4 zx{gm1^lP&mc65lX;PNiS<2})-Sz`2T@9T`s0Uiwi<=j6Ov(3xjg&HCkGqdwa>xqD$ zSpY$y(~z2hcYGb>I#WtOp$cvC+5N&m;Q$Kf0Jdrtb3NBLw*fYD`)faI#LS$7MLkx? zaS)AD84bU(H>fcM*pu{0^TshOGL|A*osl&asJv26*8S2eZ`(ZNIwMT7I%K7d{?Oyi zd&sh1l6ZGi{{reI=85hr?)%haPS&(>G9^xpT6#d-Ep7$)N1wj|S$agQP#f1pZc98> zW*k$myPP|YN(QMUQqjV@nDuYSSdX?{Yl*o&%k%ATd0#JDrEV{9uh;9XGM#}`nK56d zA(cY=7eW&3x9?jHQpd7&x4yxQ7e_v!EATy}{N}1&XfSkpoP4Dv@>6_*Wx=(rd%Ne! ztPc+Szhw!wGH<8AsO4#W86$VTD47UI`8o z7D?pQ{10$Qo`O^_mjcC!P+WXYQOVq1-=DOJ^=z}ck{+x1zwm4V5;-sGLh@D5aK@X) z|Dwgwt%BMDai_7|wsn*rMRc?TkZrU6A5ho%vw!vVBrC4RH01JvTR;t>S8{W)cq;~I zemq*ql%7NSzt6{!^RKSJ*3}zPGn;zip>R2o-{BsGXkIzn=C(-HPR*i89w?M}6vgnZ zyS>vnA-6{Ro~=uTJgq@|cLr`uY_rz6Y%FgJS7W?xso}6EV*yr4&C0p24h>f9Xf$Uf zpve|x@vO9Oq?SfgU_`39xg6wp9yO>CsJf;eH}o@@9UEz88RS2AzBs_5u?h!W3T%@afuE*_mfPCdpy|jwSTU&2q z$WBs<^F6cej85XLyddl#pPcu~NK&a6YJ@!OSJis$8|p4i%NgIE(qyCgC*mPa>V`IC zmW}n>BKsE|XF9>%yhqW>k_W%&7B7SE)VFIp*k2$bL2+L831iJ#rRcKgW#dHJBajv2 z@7IgK(SJ>+#BP*`YCm#yeVqw*V8VbEjS!XM>@5?1Q>hb1zSS1GNe;PXT#mA5Yu{3 z=kr$*8uX*pj*KQkDg!LDHb!!f5w$6KpJM%TJ#oA$d9w1$EMTbk1`eD!|S^fT9C?XNZD-4?%J zDbhN*8#8j&`a62}MH(hMNh+d&O|$XdcS@z(wjX0=AI8rm6@gt(6j11t8>O- zFR|j1Y7Ga^sg$pinxvqOl|-?LhEw02!J^wzKyNM?FO20IW6;~xITB$6m?bMYfC`?v zR7@vmp*{jVAIcrzut8$1qHz5oT)S}wng4c#eZ*JzLf|ai<(;1R+%$5+5%BYh#O7kF z^6odt0Uw{?#`O(?w`+lC9G zLbV#;Br&Q{X&Qfbv6FbjU=4J#-1nM6jP-&WoY%^w^02BJ)vS0Na2j+3mEjxH$DXj> z^Kfm-=WpJRO(zson)~Flm>SJzrmy?!xtZmLEb$naejDj@jcdmd?DR3+%X76E0K-J! z+**5}bK2vTntwrvsoRg5-Cr&gpvKe7Nkz5eVvJ=f5g#uz&#aXL z#SZ3&uGnSQRUtP*)MnpPjdP*Hy~%uZ-^-8;a#^ypN*=5cUZju2g!x$3rAFy`j_HLG z-VIn!X>dUAPlI4nl@GxJRJ>PByh{G@&5?qUWa3$~{`Uh9lNO;4us{BLyO3{EyivTzc;aCe|PpQnbBf0ORjf#5_ zH=A4^&Sw>$&RLc=cO)0g%AH~$7S$OCIJ)X4kQVX4e&yCKu-K1Okf^Ciltefjan`#A z*%u^hFf4Y`4pu7E9Gw)G)QtSz@5lD1Y1)dB9as;cq~db>-lMW!4B+32uIR)S2pX3K z(`nzp@`GYLrB~oqyh=8WM>Q3_RGt!dthGz8vFa~I1c(J^T0MGFstmA?Tkek?hPh{h zD%O!VAK*aLd%JxV(>igWfX66iwur|~{Imo`gTyyRdQ^LT zb0hPE-37M%l;stMO-mh)UT+{tM6$IzZ^F4~&*Po2Yn6NjR^AmiqhxMJRsN%h{CMeD z`eyT{)0|RxQJoo$cn=->TPfp&YJu&BcTQ-09lHk48=@r`()&j3Jo+2i0JM?DxW zAL*nYyN`|nQ(v9l*4K}I%8XFkWpOhUIo*dh3jw%G4c?!$wk|nCbxfWD?%K%_kKyQ= z2Cne!+St3Qmd!{3q&Iw*))&LP0P##Px_{>b0kq?_f=AjxkY`YBxy&O+J7I24_`>L9CK^o7>*-@?KVu=@8?ErEk4$ zxIq?Pa|<(0@5PvaTA|zU1L0c`6yA@tuIb3yHjI=^n`EB3r0#7R0_t>ndo%oH4(1DA z6^5>ai+3@Qiep({(>$_!-UQBcxeOM3-aDbqXkq zc()Z5m_q26nvPG>j#xfraU1ctpKsOaR@o^nZAN&&!eK;TfHvHE&&d^OoG&PIyO|i? zHa^#0XO=G20=<G(`cM*xOQ}av@J@Fg zKoUR@sj9a%mPGOfuy4F{&=%0LebE2g{(Wl zv^Oc1H%p@`7>hPilO8Zz{V&cWT%A|1Qt{F&cx1NT(6TiUqeMnl-&dRN*zZV?7 zsHTQ9?IoA-bb~Z9i7Bq?q3PtXb9j$%PPce5FjcZdR{vB;Uc%sqW)^xqbP7AsDX7!< zh>+(Kd_|&{?(&5E1w6#sPDgxLCtsdOc*J9gC4R__?WJ_o_To*IdpJ*Q#iq>Z4e)HQ zS+nk?bj3zror=I|&u-^G7)yz4*H$KN?~f$m{=-tJ0_{y?VdfR$XFlO9jNtV*)@LdwH2}xE)(+vIkt^n$0L3ZRX(&avU z+QH2zA*$eO$cdNP3*m^A_@ZCj20otw|A%mZAAm<)6y#pW0EW@1I;p;H7!_gqkPQ;vo}k8g`_T-)b;F{e=s$*0O#xXz%# z?#Z+$+-@LHw2gZSvL3=iICb6k0*pw;R&4WpLa=!_OAX2m6Sy4m87>%#b_hJDcYfkr>_pZJl^4SBMBTqgj?XYeBR%BBOIllLlla!iefO|oNe-6^<= zFGaIWv8cW_bei@5A;TE&)->7s1v@nzo&;`UNU`rq#-jJU*Av0i(y@N$?r!_4T_{;B z%=9r%>4;?&UAUU64z||RH&n6h^H&4ne3GZ;$Wr8_Gyc(+9ThUTj|DN7>>8U;M05mO zxY4EW*FW8T69MqwwxJ7uS%H6}LTDc+5?cs^Qw?^~e4j2ddgUT%b7v>ug?d-!y-a1G z7npw#pH1YC_=RQ(TNP{xDop$+SLjvs`tmHwS8xFebZ$A1FGXP+ddQ)V)UZ_%Ag9&n zWac7#P$wd1|Hj{PmykO62w2iP-QSwcpl}IKH-CO%x*#rdg-dRo5$j+0mqqSkMKJA% zwe7^L!k%`nSL5JN7cY_pF2!gJFzsg4=ekne0AHQ8u&71VTpnWuu6<2ZYuX?}JUDrN zIzcw835=Q_Ya+|3+>v-dD0;uq_g8*;XSnETV;%L?U_3yPtyVEr<2aSOaj4+#LMog! zOclE{(n~@^eUU{&Ibv%G)ucZ*N{ykVilVVIr$WzpWrx?mnJi7tl%X8&Bwysw;%+?3 zhcJN6@C_4(8}?T+BAKqWR0uzdH3v^@(YnaR5QtXB)QSY&l>$yaTFKx^n1oQesJl<5 zt`Poy)1wPl-Ihzd=rpsM)^ThZ8Pe=NCiD90N_Zg!arq<+k$l2mgR4LX#8rU5K1b4y z1H|6dNm1BKv@uY>HnOER&HFG>D<~DM0N%^-LsvvoyTm27lek;sV;L>#2p@Ea>6>hK(Let(=arpGeJY#2bki1G^x6;tl z@#>LMn7zQT^&DIrz>RgmXz=S4^MQ9c2&?ok{dD3!>KS1AXXmBF{)_QJDeVv;R<7k#y9^7*VwE{LUdlky}qSKkhWm>-$LQA~@A zKyXsJNb@#0{o0sxd?}S^t;GH>7e3uOrbo_-&w53R;l|3kSPYpOeT=vs5vE6FtdAOu*^S5 z0X?|_+#D1Vu8JstecS>cb;ls#t*vMaR=la2kT0h}Pi6Io@zeZ@x2bCul{6F7Rr&rl z-VL5pq2tbq9un@RXXg&1#qQfP7K(|U&<`w=mI#8bbGsE*T?G|WBq0&S!p$S$mO0H$ z;NVY{8UKbzbZUwuR%*y5S-;bLs5zmRA}LRaS_{YAu6phgrJ?H;gxlz*H=um!(>Ht>^$*2K%?TNYdb!A|rl+O@6GNuz&-t ztzP|K=SLA=^kwzp9^s-47M1T)taD@eq^sqc zxep39%1g@ile7aJ{%T3spfoXu$xAAcUfsXfu}K__Hi9rf=!kmoj?aI+BoF`73m|9q zljq|r5)$$e4Y~mYj*Zjb0NX)niiBMPxy@=`LDS15hpyjJDB221VpCG@KZCzJu`850 zo-hTn^<;Fc!8!8P=9?cz^Dd1efqc6-^edUEFlMY_(Mlb_Kl}2^OidsvGvlXzOAM1x^|&n|8QMF`K8O1S8vdP=%(_OVAI=6z%8s zF{u;SG^rdamR%OR>(1mS6eD%~EUtrV+`5A;0O=QD*t()54Zg6RZz1}J=Amjx@zymf z^jih|V6TYjtdwMzv~-6#RV$Qc^IrcL*&5q?h%44Y7RWm<5N59!f+B+SkComA-)>iP zpUb;;i8V`v9Kr0|K=S!?DkSh`E5k~v8Edo^Cw?ZNvPMxA)#f&|qVD}x?C5amh=;2C zX}#HCl}YEmurtI?&5)(kQ-@0MRlBxu1&bEPpP^E4wfSDh`DvQjptzA#V)J*DYAxrP zYH5}P?*5p7h%crU6kLW{rbd^fcuLVZC4teQ#-!h|{zVog7Y;Nlpxgq=4I9(wmMj66>$0-=ptDW9F&5wH1 z89p8ilKj5Dlxce;_!{V?)+Ac{5)lS62<1opXl03Dvv7HN8RS^))s@+4dD481>xJ?< zjifj4k+QXw7769!V1Zd?mdGV_RZmEXsp;CmWb>Ke?N}f_1)kgKl=D%hlSLJur6iFz zpF|ZgA!!ITmL9)s`!7*H#AVaP5pinRN7`*S8)@r0sb9JPcAMu{a| zQkU-_s5FlJ_0Zuj_=i+`JZ#!Sk%tHaX9B*Ocn7Ut z!sAAf@o8B`zL>bX`iHkaR(IuE`tdf0HAn{VGr;6s$JqCb4r_Q~sI1_N*@l?Q zuvYF3-TmBTdYUTz^fz^DejCr(wMng}x06<)HlR!pB!`)I$xS?NvsFJv!T4?d!w(|O znjm@FRi`j3S#pnUqv(-^Kh8pbeQqrhbPn1HTZWR{Z!2q8-s_A0r9Hm-^E(CUC`Fl%J*YxDDRggOQp52> zQs60JTxHExVozuXJOh5VuUe;Ld61+*=Ds?;T7E=2MIeWRdR)|I z{}~xbup+|qGB%`z;x)8mEIToXfdc0io8GP5?5Z7r;?ol_ip%wJ5aP1LMCfj?>{8zV z_G%9kp6>RAI=R5U1S^Qe)YAWYtueEIyU_nk;reJO|PRTP|;+C zZaeXTLu6xODR4E^ZYXENN7FvyrmN0%KZa8pM>mFBh-QTGgqDEhm?Fb$_bU++IR|mx ztNWSS7H+pVC))|;yJ#n~wttbBrfVfQ;Oy)+?yc&INE{^qFg(@@{F(~3hpx@F{$*?2 z0#S^`y^ zwU8|>LS>wF#gA(@6CgRI>$3H1ruX@@;&q?H0>O37DBx?RaZkd#N4&UwNvTSvGu0mjt&}Ob z@~vx{>k?*`hbl73GTg7$PFdgjUwxcZ1^s_aorPOdapSk?W=JTbL8PP^%?Lq3NdYCK zrIqd)9f|{_yGuHy8>G9tMvU$n^^WIxuj~2!3%ky?bI$kv+_wg=5gM$QN7drCl$s>- z_sEzbTJ{jCT5r{`C_9vNt3ak5=RICjl{kD+h>;*`=jm2|)0pm7D^ZP*rqib2cq0A5 zNtgyy88448ev8WmlmBr(vdWhOMWd!b2l%Z$=V{ZE;;lD#ST!kzmfG_IcgcVA$*b<^jL61|GC8*IqZZU5O|SJji9VlUGZ7z+`jitq~NOtJeM{CGF=NI*?W zG!S{5#}gAC1xx(w&);_4Zq!ig^U94<%*<-e`fADFyE{T7v7rZ+CLc_`KyRTS)%Wj29k2}#hzul!Sw8BBwD{Xze1$Lx z&js@^SLr#(kgr;Kx$Q{=xSrdUq#}<0Dj6=I9)%g5ZczTw5Z>@pPAsAEhOcdo$^gRh zkq=NqWg(?hun@`fwT4Zh*So$-TgBd)f9adVGTC9aOvaDsm)wm5)aRwU5%5@<%8s=qBFURRm6z|)@ww-8TWKr)y&rkp ze7uh@-$rvExS|pjw;h|oiT?IxmtYRzBGB`1(WOfCAMNM-3Fe7dwJEik`}AFpPR_~u z-2~88%~;{^hi|FhOui*zItvzbZh2H-v4zFvL5Zc+SwZ-N8Naibhke&cUQ}U8Z(6PK;d%B;=eNMO2NxMuwok={Kcyk`^q6k_?b#QfNILUEwqhCI`EUf zr|4ntA0Hsgo4~HQpI;`3?WvDp^lVr%sXr^6KT`H@`ud2b6>k?0Cgk#;m$ujJceOj1 zP-CPd<9dPp8XYcxe_J2+TDckzE5C!^03_z|Fn~(Iv%)2KIgC@w3}*q1d$MfBdIcU15h zB1!v;gi1&WO`?#?zys8LoTfDl6WN3aS;pW^q3RmJKM$(qEb)qbc46! z`~kaJh$I=mh=CE^^$Ac2Hg>DHyKy{Y{H}oxT~P0;;VtHJZ;yJLeqsLC-!If4a&m5^ zYDAo>(%f`^EPB!5Si;}C0oL?`f6W!2HRdfJXK;g`8as<|{?6CksU9Qn_C{aYC-$tGHEaC_A`)TyopOcB+yFEN;keji}ZF zRNpDVfzSI56R*=e#bm?mg&d{Roq93ENwf1JI0n47yQsuDNRx+nr_D^ro^fJIFIjim zf3jhwNg3V+r+B~Ty^86);wi70Fm>p5 z%5{sPycqF!vsbz}nS-a=4da~yP2k-46V#D`1HUKAEN^iIIFbC?DoETtaX1;*tjr_kZ?3|0T z27J~73jrV*C7yR}ntrsRl%zL8-d+OAe4$(OnISDA%H{1k27PrH_D-Mi9yAmjHa}|y zr+EN6T3GK!{h>UQZw9`kw{)h9w|zH8!9VClzCm0jhGBJvQoFa?I-x~_7_u#{Mgxd2 zOU4KY>YQ1^7c=)0G+ms;>~~5V*%m`+C#7V|<-@@`0bAE(t=XkngJw9;2#H;I<&nf~ zuvH#mUbKbsmJpQ=-8glNGd!@5h$X_+%^0(5Txk9ap)vb0@`=b?a8gBe?u|=qwz8*m z%>}k(lW*0W5s#H(7jN=Q{X|s_adZhq~iSWFt6@UlaV=_&YC8iUwL= z3jY1k?wEqG8aIb&~)#z5TCZ>+3_*y9h-5TFh(#A3>VeUcpoR} z`*0L0DS_Q@v$BMdSrXRnIKQq%XfKHn+Dl1_5s_zaLGqGvq-%biqaGG)@qedatV1!v zCY#h$j2%Y16A3u5AaY$Gz;d@!9i1C3ZNLc;!rs(C>i6o8oznK!7bc?tcO(!^m;SsanjJHBuTYP6|I_Y?K-6Xk5Vi;ecMO&LC_jO;z zdi)%$;XT~iC57l&kkzN}42F>PK9-lWrdDIWp}eQG^Ev7p^kG8xyk>V7cYzM)>$8}{ zYad^ypP;C~!(%be`&Um&8^t7>Ni8yEe408)-w^2^HBAgd=&s;3I(z;aRAgq)?FIYl z7>_r+`yltou-ni`seX-*b5m^h#*mQ#Rr+X8@*{XDOmm<1z772{7RK`Eh8c%&laV<2 z*t&d6kHUO~2Ew%xmFw=~q+XLL-SGUO^-QX?S1g(P5A#CJ()k2k`V4jn_E9nEKw+}t zV-2h)uDNERW?$W)m-qH3LZ{aB!gb1dx~bY@$LOMV>H_to(jw6GsKKPJSR%36Ni~tb zW>~jov2b54b+ryJ-^Qmkr%DzKm9hx+uY=a(B>H6)>sZyl6FHoP_3w*RF*;Ow^;s;y z&Y3HbrLZ2AaDdo(x!rOK6ZX(EdomzVp7`=X%?ekXG|)s@i!QGIzGfl1CyhD?J=c40 zVVP3;T^+MByy@bHRL6%(t9;10q?zI%7ljIld?1=Y+^_n4eMi^-trU1nSxXpD#amZk zJVL4rH!2_`DARjK<^Ns#qg_Wqc5;RGn^ERy4-74Xz-0lqT|N^83;c9w1A{yQeE0qd z0bi9m4l|$uuLg#>?$=e=N5~k6+Sb%w-LIl{T0TagU7Vj|d^@3z9hd;aMfjRDV+?Q>g(4?ERMJ4gn%`!iwt zkA39BZCY)-T%JY*q0Jw!5aV&3r_rPUBJsAE(*m{Xzx+ebA{p+~T>E3#7t#ig0hLa$VXi^t&`vg{en!wk?;qYE(Bslkw%6jdJ+J4gRT!$?xhcWrDHQbq&dSP2 z1bsz;+2io)i(6Y{wzN=_>ok@ZsE9ILP*wl?j9uI)#?bLBHx|Z7pF#KHh;46zD+*5e zn9ORd$SaRrULaWRLo$L(B;vAjprgSL#T5ocp6h&)Q+t-e=SBzI#7qV))G1O z%qVa-@_UCSH+^NnY;+ni!Se#wYAVI`6`+J2?$FWA4K1$o-Eb8&~s z%F4|@h>r?-yTaeLQ~L`Vr+WX2Mprmris!*5jy-aJW|1IWk0e-uwU0nqkGJL;oeb}_ z7JQ!TR7s3nabgEG{>vo6rmokaVbseTVBC}0uY89 zd-S*-<4Pj=Yo@P{VH-d?;eb?rG<>i7FL`RryK3A$1>xSS?{U{M6`y^cYVz7{T={8g z*)N+JmKxPIGCQW}KotD2x*-KRj@co#-Y}O3pK}-BV|g)Mi&p#PBl`8SG5^_g40zc2 z=AZDKuZmVrFvC@QEK*LBJ}kh^`zrai>qY!P=6!9~^}nuAwRQvKK-xiWN1iKGvWf-> zte)A?_h}Wv`I)*Jx8?Y;u=ohsZY+trUS4!o;rfH=g_qCHbnaq6*PYrv@YK@|V&@3X z>y@`2H2R$@TFLy(*nlIHkFqb;65K>th)4x4ssIUxOkRU0ta9j^!-yCXo5k~o4>aR1 zd`62~J%$96cLm&Kj&f*36d}@7)v^N{nI5NqWrS|n8b2_aTFOm+W^teD$!`ej^N#j2 zWfxtcf8R2YozAxMGo%J^&y2ivBy`sI1nmLaLlJ)AYM1ks!}_vEYTX`ZgVjC#jQHd+ z^%t3y4RAe!0}Vni0iaGe#b*~vgB7Bde_N*gT+#is$umEiwj#VTId$0do-GoG&gzhk zfpS06dqj@`OT%h*g@LSZ>A#S16=-OXoHv_E2K#NtpEjSk_# z33+H|F!}j?9aE*_WozP~i_9O??K^hpJ{^ZO(0XR5Mw?|#jMRVgv?*v> zf@Gujgole@Ia4ES*NI%>?86ZnJ#{n)492y$G&~-Umg&bS>IHh{KD+PPqWk%F6rw_2 z0EmW6(-E-Wn4Ff!RV7svsdH5=ZSRv<9#%IKK)!Du&G2pttT`pWFq?B=zSnRY96x&} zRR5AxINScfaq@cCD{cy<3hSf2f%&RP3$;#FnX`N9DZ8D;x-XUq&{l48hN_kRWXTzq zsQHysgd;b>>5C>igH3fcp}idzbU0Gj;C#>UD{)jzBlI=cT4S1P!AgERrgfBo|IfLK zr*E7|9h=xj;f`D5PA->s&VHjxgV|DD8PR@sr7csQ=&HQ?(B**o3`N?pgHj6!=Z7rS z@d*B3SEXkRJrKEjAo2?qHw9Nr`$%}nbrEs^3y!tkC0PLdDn#YDh|NCR^^BJ*W)-t) z(bhwxeY{N7y86Nr&aIwjoQ{;W1eA~_t`{|y#)l2TO05F@qr*q54)Z&j>Z5__AV5+yxG4kEKI+i=L zmRx2wUS{DiN0M#N+>#Ft_`ekA#n+CBQ1w71t9N@DI(1b_Wxb^AFpOWMW4!O)pj9lk{ z4L>2%e|v#N`W^j(CjNuJo!J%|^tBE8XH1bSl=8gjmA>&i0TI(*@)`Fv#YcU#VwL+E zsi(z37u&BH1=)X*bG63{q^Sf4v+Cfd|dr-16BRBNzaz}Jcu$y*|X;VU?MhHl)g@()zC7;_442CP$hAS-3aaN>$5J zKbyarOCG%>qEB;Pf(pXoBs`a_%UxHePH%? zJgEa_b-riO0VLo;X8^PW7?yAdob-C^p``~NEE=Q_oi+`7h+thM#ybr*a61pX40-g0 zIP)$)ASP<>s8vYgFPu}|N`CQ{mGM^J-x;YH>h;=mSgCzlE}V_GYGiP>v;*yJ{ed-z z=6fjTe~EI65`Q~-pjpDEo6*(}g{MPbbF0Bm@v)?B^GiK*|8Rs@5FWMwsX9WSi^ zm!~d{K|2`q$lb8S|39d9aDLmo{q#+8F*%e4l&Tq5x!jL5mn4f~&d;6Z4X?R)J<;9- z!DsyBpS$sg|S>XJhJs{n7<;w!(k)abb;axE2 zBBZkD*=*Q?=U_d`ZFJeb8ynPe+a1)9{OKE1U`I58_givLaMnc9`AWL?z<`%|0hP#>|n^r9)7Sl=@JTqOFl|0@hnES6kW>yh;9k>k4! zkNT{ZU3+)u*#S0d!4JRyVHVORd53a>16f9W$GIXa2`3b9+TDiBmbt%6wnDTr2Q446 zyTs41vObH#^P+9GHl1s>rIN~soFn9Jj6(KZdCc+VH*BGqGjo$+*sF45XEQy}usdLT zIY;ZRzq8Y`CK91@qSkd9@)rJsa_3Y-I3Huqmty@1)@e-sIi$+q$Dm(mOg zg*d;VnY4cW^~@L_xBO>FQx}u{faCPalVu<-3k`n7RH?Hl#`%q1%>lU7`!b4 zzc-$`i<%fw=u4$r(K*82pPoP$|%BN8PV{tnj1aTbYdHC3#Ofb%lUKU zYx3d=hz6bHECpch_VBL?#d~mDFur~+bi2njiDya$tOg*_s9Hue(B7pv8-58&1Ut?T z$dK+e6EUI~Jpb?zmDE~HJM{v3JQlc|#$G1+!Beh;M2^KZ*N1pm)1Nv*;?#mQDq{3$ zh5ux@el2kjLSa?Jmo)1;L&X@pAjY73`x{Zzh{Z}!zzn}w4IqlMu3Zp> zKR$7>zqlIy{D*@PJFGMII<;}e-e5eJVD(EU8)KpNIZIxN!KQViZ)x3UH$)EWzk6Aj zJaL_R$LP~XTS5=P(CHUbcU?j)+k=UuD+DvPlVx-UC$_9iZXu3^MTXgXl0PewL7vL) zHQu~s<_;=}NO8kxN>#^zfD-&N)z*jQ{Oq?3VCU8gzS|*VjP*hizBOoW)-BZLx@g&h zz+K*M)sM4d{80#t@j83R457u7X+e_KvACoxU2?inD=jP$Wi5_YP5z_~2m9G9yC3>T zL%g1V+Y^BXa^byio*01L!6Xs(mI-BznODL;0U_6PhyYs0^kk~@gzC;5zAW!vwjE2i zOyi3bcQ-qvOuC1~sCLdQ`I_1+&N*7<0SWLDSMHN`7+ra?uhnx8JgvtiIPo>Sa`Nz= zEh{GrmT-J&f@9L}pd{=A76d-Fa$8gbY>}tsrSnlM7R+7LHqMhLMp?NRcZ=(lPM3F; zZkJ7!b>6g&=YJ5qM->-hy%#{i3d&%bU7=cGtAke_zvWEMAF=(2ACcoz9McnsLFC>O zJ}dvv@+BW5S>z=p<;VMvZF5!T@kpe&STPMOvZ^lZ|8@t0e~rWz@$+s4caaVhRec zAi5H15Z!)R>9P}*OUxQgjQS6$B-v^?@)ByNe0p=PbBM&}9d4~_}! z%N<7FU|$S%c&$)@l{FfjrduwWkvKK>?2|IX$&vTs0ht( zS<*^0qh!XG-ySx-{^kq$l1e`E#%{^ZI=_6=SlZ z`74Oe)sq0{?M{BWx5u-{H}1Oy`O?JbA3OB})};2EuqGmpQ<>b!Mw>xv4cF_x29R#> z%j>O5qO@P*NRdxvo8Qat*|uo1_G~v&kasN2_jNXxd!k3l{*mrVUX2S4N8ADnmzphw z?hw1zokLM16bmMz}ZoEDvO~u)2YOU%;vH6 z1Pj9Lm)*C64szJx9m>oq>Sl`mie076m-9Hgy@R$sXA5?9Pm-5WWxevg{I*cS>$wqc zjVG4f4h~n)FKgH>#CZnKe^4wHHe%1XJ#4g{WiE2w5F`0pxX*+7JKcCpWqqcg&ifJ2 zp-rVK-m8U^baXdulDMEhUVUSF>FgRGqtqcF z<~)>w{BSXNVc2Neo7WOpgc|*9ulrr-?Af=pyn%(vE*ZF$_Wre#Pg}psfEkX^evzF zjxi$J=q&SZ)$xvne9nZ%@?AcdUo441F};TY_rgG>1HHX~XcK;ei$cwRgBnY7Rc0Y) zW-(Wb#v`k#Ood%ak_cV>EBzOP2CFqhv^TJbBqL&x`*0o`GPeTlI=iDT9$M# z2jF-24M2GM&7k>9Ekj?5PzA2~S7wiHf*<*?>eijZNoh?`l|Sg1?&g5klfiY#Yb?~D zx7INfwdmAWtCS&cv8eku{qTcDPupgc9jxo)6zOjJjbZvrfi$tZ=@Ln#A@X8ZY!i0= z-`7_Of@K~LjlGF^v{q#oHA7$D8y=cj7=6gP5J<=n>yL89KyAVHDZ>r!&qD;-*9AGW;*N1haYS zD&hiW_N!jV4nY`2_Gio@fn)g=Cm8tCQM9rNkEwiG#{IO8ZPMsg@cE9P1-C z?)a{zvHJ&g4%vUvrKkpr3zR`q8>_8aOsX0fH3(%(s!UKr_!_~DbN4QP2kMNbj>G3d zb0K4{0@2qAN4=1bAU%0$?E{@wq>PskbANR*kyD*Zfw2{d?fKg${J1{ym}!(2X@`;V zD8DjK3}&o>T7L-ReO4mE;q!pZHhAvOr8ln-NWcqA1mIcs4Qlh_5|uB3wzB%K znzqG_7C%POcTSvAh45I8=QwoWvIJ%Y4E6Gso1s0?&?%lC9{s#M)THGubMEo1EWO>Y zPtARSJRLSaHI5DMo`89D17)4_USVw#KJ}>{74OsNAAAS5AaLBlhR$#vP=jcdog`?R z;s~5g?J=8adUlc=Uw^h0H}Z+aW&I~hw_3DITs72y58$?gzA|d-Kv;d=DPl)P!k9Ox zHj8vh9onsH>bk$XdE%^lpB7r3y&cKwqrexZgO^vYB8UftwV?S;U-Wcf6`NC*lfYX0+p{ouO{pbQ30snH+Mcq%W62KVCC; zrxVSH>HORyeXiVI4A)4NhBk1z{q00_Nt`>hlO2i^M*y-X!{eAS#3x66B?X4-s>rV3(2%oo_9Hls=i zTMiBRQD3ihbcg`6T_5q60`dZYj@@fIABkf>qL`A}c~O9K@g0y(&*QG5XFyr(h|H-f zK8(dFVGvigJ)IXK%Mp~#E~2(Pqse_rHO_l3*ZR4ylO1NStdnp6=cs??rPJISb2oI- zkbgV8&oiGv`tTc68JGuumyhx736?_alh^fES72s7aj+Ci?ipdUk>3RgRHSQ?ZSVeY z{kU6k7~sr|V75>bN{g(geI%II?)j?w{Mg938@xz+v5#pw7ceoAr;G)W>jQ6RsHwKTUHLp-`E16hE91&#hO>sEKljsi zU_c?rZ6@r%d;u$?8y?_`V|vk@K8C3o2yI`mZq$3b5flG)mk2x@pScR%OpNs7UNt-y zcD~oiVtfxR^rrZT>c_-@4Pz$~e;aObyha2Oh7POvASU?!ie@4QR$_rOofKvegd5T% z2Rx7NVck)fRN=*VZC6cw{j{a&{^}mtK3!|?pp_DTFBm>Gy9ycNE@fks3nTdFPl_$Y zFd8_XL%be?!M@734(GJjMMJG3#3G>)dCT9Sqm07Isf`O>?W(X7=8bMtkTQfOy)~GzYef=CxD1NJ6=0MdVF}ydJysfek~SqI}|Rm0TsLe_tLdn zWEP--`#RLm`Q8ZHo1cgf*-jzthH~(}pyqQoOxVxPBGMHmH6G=bF}+#RAlMGgCm>Iw z@t$34cN|V!>+m@zs_2R&qb}qA ztQlTnu$;27eu{&PMm_;xrpt@&rpV*4H~EYa!j)F_Q*jKA8&ryKrXW`<?1N0C<$n%86d*n#FKc7>pC|fpl_XARl+mqi1)h>BNj=he~NZ2P;emIBPgum+iNG z8UIn)QCpb%OG)lDNd1IHvqaPK=dxc*0eX#@otQjL=GKyW$GA;5-mS7`;L&m%4fdud zFI0R#k9yF?^!+6oP8f@|R_qyB8lcrud?)}Y18p&}_pj%tv!BtqRSy5jkBtOvb#r*Q zIZt+KT;m4HZSMJgPgcEvq70;^>0DKt$2^Z_jx%ah;mCF=8>ZB7Th74mW+4?9?mj;M zgxyi6-EY`Lp?^g(yMw`t!AJuPUQ40OHL{}c=uWK2ZGtvy^BWH*JPlIUz#_Jhv`nvr z1ACF@%IDe7UR?9OG)u(r_idL+r^coBbUDHeF^}u*1EnEg7`P$Ykw5MaJr4oc{bT*= z-MD-|;!cE}Ac=244+NM23R+=nPcF=2YP+qt2%KucxMfz9yNjXTZ|pogX4OC1%=hc~ zAoJ&^aN~$zh5;tAYUh&h6S4F5S$kl&f<4d`R}!kHh9xH(*I64AMh@RMq{js9b_S7> z8LkNlQ{#TQoAJ5VTn!gOxbKhkpw&pQ(o=r4@WXKx3py~{SJ%>8YLCm`n;ihf!#I}8u$YTe#f;zp*&Rm; z^SRGwip(kWK!@!*B;SzCN3+A)Y}J9~18sebfq3rSVanZGl;W5-acOO2{C~K!)Si_$K z_en&vl6~dMN6ez-Tn6LFZrv9vdSd&;qdT{v#(WLlp{ghf_cBXVhHa2ii>4@3t1unF zDItDk8Ja7q8ulR6qv(IVdOkkLYgOz$*WQR(Foq+@2V+eL1N96qMJGIHb*M3`7qoA0mjd2GO+i3WFB{Tf4G zpim4>MEP|M)R8q_zH?v8;Xo(hF~-u-yN3s~y$&Gl>0$6gWzHTp;yD#NEH~a~dYvu^ zK}r|8@v#pkFHD$CjhciL^V2;VvguyF}5FXxL>sx0^ZPshb37of=cS}>{&#k7+vaQbiVRizoO=lf4;F0VO$VV~( z+L3MXaupX5ScgM^x2qP#2AgQ)7ufRMeK(OG_|5A^No~{%Va%Ev=$s%sVxC!rC8Ke` zoBXF$60z{NkTkfVi&fKcha}cdjsoLBYOJ(K*O__yQ4PTuo%(0Ke8DMf)O}W`LvFeX z-SpSQT^RzDS-5hq=)~bPw7mlvJUeVsq{r_+N3;9EULI<0648#Lqf$^?LB@n`NW|NS zoFe0IyeaHf-(bhhIhw^IJ#FCg=HdwZPHGd38!Xq<1y zH)J>J1Ap7V?;&~8B3ui6TZH*dDvg?CH(i3O3Fpbd5Tyglb+k<1by%uC{x_>gyR_QJ ztxGq3)b!qP58lJq_&1{rVr>uh)$kl$sjl#f?CZTR^nuGP=?)GXkt)Y+E~C7BrKN`C z2ACa(CGY)YY}ORJ$+^A&9fmF!lCou@9_|IfCV~#rkLm4Vd!dfoiKD|o>E=VpBKaK2 zVy)1!aQExVNawPAK!vQjaUse~y)%s)6TG~^NxzKcQxAU;<-XCyTGi&ST`Pig>~dE! zdi7^ze8G{5|GIKa;yhAtgB@Adguh2*%ZB608P^1Vh=I00AQxsYx`Qf-!=1+G=wx># zZ^wxkCJ&!NZN_nb1_!qc9xn_!eLq^^miFVAfj_Ayrnz+_+J73(2wn{bzG&?;!C6C> zL=C)5+)nZ(7Csns0*3$8Q6w-XG_Yps=+6uoGfmb95a~8E76Z}q%m3>a-OY}E z(389$oG_i<*MkyQ=m8>5?Kz+g8O&jg9;>gwGXS3B zZY0i?l7YSXw>*h3i6jCK%Y&{A81WtK2i!K>d8popprQ}ruJJs1*P98?ypS9-OS_ro{Zt)ErPG? z<~HMF$Bbp=GKNGl7xLncDI|+0-SUj&31xRqQp~sC`r0(1v9yb7zq?f|of3+?)>ukA z*TwRP;=IHfIhl}umdsK&ZiX~-8ckd%Nn)(-hj7X`Adojzq8jWa2po$(@Iwy3G9*Id z0pkYE(1qc`lt&VCsDg0ouIWlgomF6!_C=*eYzv8<)92qqvNN0McxB`^|FR3QmVraF=Bqh;_bxk8r4F}H6mWHHuT&0Ht@C41Sx;aEw8vJ*J6PFIcygB| z9({pSD`w+!$-J6nShVPr5*vMg4{payCYF0;m%X1-8qy0?Rj4!TPZ;sgV|jzyq{7_#H2bk)w>29Ndzn)(hSagkkAHx$pdtVpa55Mry)_-=T?HYd32@8TvSFPMWHV;VJSU7A%yrk(4&DX;n+b! zOA%7Rclor<$uYk-se>`#`!C5jo61v z+eg8T=;c{7zBiSr?^{kdYn!Y-M4}_VU+6XFd7tu$1RGykbVyd2furOE*K)!Fg*003 zu#ku2X}{B(!b7Q6M}BdxeWNEn%aflqVicY~9yC2S%O1GKz4$P9GW86sLnl}?b__=# zUwU6`Gb?=iq|2_6zW;~FgP=^#d6vC!3@DK~yxu(p4}~3FCrq{Fs6UTR#@;qsQGTmP zG33^i(&mk^<@s!obkO^uqmK28jDD}^(&k@0erDmVEU`1=7%RSjU0n$ ze#X$JioGVb0$!$AsJk`VD0ndJOE3(l#a&BGf@F`(TvSEr*{xx1FK*#V8?e4Gp zqrFj$+9I1qcWQ~pZJkKaIV`E*rz?ogtJ<+kVha5#1l;k{(R*vhyaOdWsA84r8jde{ zh4?2U%TE*KRle44^c~3_CqA||-x5q06;fiiOo-urfMeMB{@~9Q!jvx8+=p<0c5mOL z+H=5!v$H{un)j>FaAfP7FC0GG_{Z?gr*hv}l%3xwxPeiDzh&|6W^5&0^^|GTl0Bo7 zWgRY49Y;hOi<>WZJ&zpII!{29(D3)Ey3R0sxo+%bi?__2t?T&puQUm8>cO|gRqSPB zy$-1kZgz{BdXUeAy-=0bA@9}nKIifD!4K{uJ0IYR%HfiS*P?)P7RR&jB+VWitm`3( z(|N6466#$7W(!M$rCBIGGCC$&$2H=KGgO1tStOzN34kSP6NdVUF3kFVl%hH>`;=34 z;n$@U4PuD zaP$@bIK0E)j$+C4&+4nTMNS1bg6yBV; zwe63M#kF+Wor~?tt$Sa5>iIlU1dMr1Ph&z!%{imY zpMgINzGvGutTxrfNsHNLrai?PU>SvmJ#gBpp%G>1*>66dLgkaSiy2kC&(TvKT{EYn zA%OA7QDx?P5YARMG2%twSWF0{g5bK4UxK5Pk(>^*uy zXab(2SL$1ij>vc^rct8buh;GWf@NdRr-(>9{OSR8lQgf6SdQbj|MsQe3h2`b&vw^H z=lE@k3Kcgqu6<3vQuI;b1$P0sTa%v&9UojFm6c1K5jGC?yAj8Oq^-I&qh@AiG;x*ufHY-zT9;ctM1(`lx^F@?Te2ji_>ArXMH|}|y*|s| zAal+|UqPh_|NQta<1E-RMFZCa%G{Mw|MhL#Z6p&rgdcxx17g^9hec>>0_8p1p&?9i zvOx}{TA_-wYGV|{1H1%S43cM1&|YJKWb5_z6b?@YL2I);gqTY~+{^;X1EcpRYxl@d zIHo5GgUYQWE_PN>7-rBdhG`lce++RXp&+#i{bzws_6l4+`Nv>pSDnBv>=%Q;3na?C zCUf&-JCGOi8t+yS45zt&dlHCl{=vQ-dqQ9v#Q=L=ov%9%l}EYBWBakGih9}AtvQ6T z0T;+&5RhpQDUsGr98oQ!AWLzW5-4$^5IFpkl3l`$9+OTuFqWu9aeNvXtWDk2cM`T5 zHrnin*Vp2@zsQ}Upu%n%+8D(nxT_4!Bx+|$>3wA!tVz;9^Yq$gyz9dhG4FLZy`7ru z@&-j~VuQ=MG@rx`?WG^L1do~tfXjp!svGvkwbFWV-rv|P;1SLO9}_}Y$c^-7i(1Gr zj{#G`yw=^XBH3Yd%JSI-KDZHUo48;3Jm4wARyyLQ1mOHO);*mirYnvD&v>WU6r8qR ze?hjg6ki9=6f~C`vu_)^EBoP08tl4pHzFu}8H?Vmsk~fev&b<|LCCV9P9BS<+r(m3 z$XT>F1^O3e@i1&o@O9Epz)_w29Y@R-@d_-)+uj5MaI z)uzY@IKMwaP%VwL6h!e)Of+ae#r1{B_KfRa+dKX6Yy09be9f4t)|Q&$AFd@r%Ps~Kw{7!$55&&jd?~|4^^Yj1 zJkut+V|V0ebpamT6a9f8X}9fv5{u^m#$Xu&|1mOvHI6Ka@Q60|USU4RKvk+^Rrnr1 zf%X=M>22Ur-@)AopR4xrKKq{Fnz^_++n&&EeENm1`R(wuR;?0q#afgo3!NPDoD`Ss^w** zbxPyO{%4jwmZfCpN(hL1&ph##sOb^?;%_O3qnTY3m(!uhcHHB=&!bKlKjF7LW`Y|8 zuZFp)T9xv8WZmnLCNuGy51<(|vVkC5@3Z70N2BWM=(M(cyTDO06Xup?1QlM7S-p4C z_2%q`?e|?;*X4kj{i4Jby*P;LOy3>SyyR{Z$8x{HW}CLBbf+Nf-kP5p=BWTyY7p4@ z?cQv{zB5sr?!IL1z$LK9ykES)F#z5!7SeLBIxMa8UY|6jDu;K~(-@y>f&0K_)02Fg z)yC^PL)Syov$r~u65i^z{+8!$VoBE`K)f zc?a+-QhN~p?c2wgDAJGHtk-Ph7yt<`S+dURyteIFP-`|zmzf%H@SkBfb$8V+yE_E) zs8iIuEiIf=cL! zYj#Zr@bU^l<^&xR##ZWjy%%cj+~cKP46B~(%{5Z19iFV>)Y>P&y#sNjPsN^NX-w_ znGeX7Jb{-*>4* zZaGkz>-(sE`ZH=kBeH&aB3imoJwygK^_>v0A{{j}fE)=YSQPZpZe1$xv<*t-8Vw<8 zgh&=p^~WVfq_Uc`>%a26^oF90glfWbT)&j)^T&6{p|~U`AHEi~0(Dq-#-TO(qA}V{ z6IzLw_0`ipTu_}=?@+r;SA7=pzbeRfMwS7wt*^(g$jaB1uRM62S^#8S~ps+M%DuHk0Mbg$4YlBvv( z{_6(*FD1FJN(Q_cY*tW6E06K|LZUsb4zK!ED*0WLhW;1`mRCZqe{p^ZOOEByn$ zFcz&?Q-@UWNYf#2F7=-otK^G?GZ>jvkqu>Le$$;_ijC;&_!RYSc;1nTkWu}mk`;I` zraZ}GF))`pB`~){f<%Q!k6i^!aybmPO&8xw1NJO~LMVNx9q+Q7giprMTlsd)I~BZ8 zOgd)mD-@RI@@b!MN>qpxlTFFKTl)U!Xnix&`uO<;=a)~oJ4H%Qk)!*%x-yv?$PSsW zQ9DfzN$^(47%xCE#LIH6Bj^HO^~qe9C~W_AXo2=#LrNve2ZJK)1-HJg*-+Bwr$-y@ zCKLM`q*p~t@_Kn@M(B8FKwnxC*i7qn+|Gw2ve&<>!VebaHT9f7*FV*MkdD1i%t}e0 zB$a(G=geog0~pA!1KOzoBFKL<;w%L}$a%++1v?`;yJl%|-Sz(F{#;dUaxo!x>rz2D z{S6W+C<13WUONwP*2YW~X_A8@XZ9w77K{n{eqqOpYXv;@X{%Xgm3ZBO;&~vU_F80& zscLrzSss%-elF5vroP&|p84E#slhpRdxzk7E`FOY@?H5M5`uH+WeAxZNk80O(N zsRhmFDUZ!PStG}z4mGG~Bh+5@v}!M$%Ii6d*H#UbU&AWAt6oSCt{y z`zp*Ag5nJY#b2rlIht)P6CL3xeD>pAs8kPuN7QedZjkD~c+-KnSj&mLsj`V&tDG|V zS+TRG)D-8Ck2GCk!-8=2ZYZsK9G8wht!_nSu|7G-;*INDB;OT=|1V9=ERci_>?v!l z@ski5{QsId?`Sr|w~yBhr6_7t2(_xTHX&veRVDVQJ!93LwW%15y@}n@nl)8q9x19sTSN&`o(^^o}@6=`Q)L zWK69co1GflW6JrMRFEYXeFyMqyXH849>?0&4E|xUBz>K-ciNZj*iw2UB~exUn|L9k z)o04Og)M~YrgJZQYLX^Y><93Ayr6{T5-S%r-dg5{8YF^ye*S6sd9@0<{pw!y71(jd zU2T7*SC5Lh^_q846g%o+Ex~y$4^qli5aZ&Z5ADti_e2I=VIjLZLmy1?1e!<{A0`#) zGpP^L#w=!f$K<%ID_NUDML3E443|jB~;?D^QfjsvW(Ew+v_^8q5{Qyv%&f9 zM_^x5Q4)!Xrx^jaTGL?{Ly2FauLUb!H_+{^WHr=uFXB+7`_%HM$hW)W%Xf?i42uVF zS9CByhf&0&-}KoY(c@L_YYY1scTcRo30f8d^w*rKX}G!A>Az!@w#Rc$OJT|BmqVn1 z8pY8qTZNg$iz_7A0}wmoA5hbaatW}M(Wlt872Vbg?e~gB!_v1`3n8&{DE;v8Qf+uO9r0ZBhqdbeH8CskQG zUFknw%`M*-m3#i7o0;t73$MQ|yU@DSK^vac;I{DGFR~(rk+oa_@!Fspbn&8^|QZ zA7u#J_x++ZG2FUFz$6ORGp+9QDJ4T5VCQfH0QS}|-RjgM#AiNIg6Cb$4mVSDT8`#WBa&{5oTg($((RKhxJYV6x~kt(D;1b1Z50 z^>iz(?EveWEW3L2TAPzWfRmzQ?(SA%K2oVZ<;-z+o$>lm+W9N0&jibEKK2_`=H|U^ zIFlBxH)%EV?|#@#i@&4uUfSvOy|CJMnWBy19FuSn!zCeV#@Zi#5b(}sV@^+;DSRrd zpydzh5r#MLXYsuA@pDW}*r$+bWw|(rZ(kyZ>QlkOKbq?w)p$ub1NwTQyx~mlGU)=g zxD;|6@EjL3^>dYU_4I$GM$NTvk$jcngn~?$T9XU(RHNyhl=4-jA)|CM;3NnggHMVC z@xuzx6&xzwkuZFr;bL8Km4yl+_9C187QG-ge+cC5qtV-=ry$oWPE zT)qRZvi~;0(2HK*QSj7&DyT&m?Ce4Jt_kQ!{&&)8so-#N^`p~4cLJMZx!?hFwiqDO ztWcuMwr+@==x{WP?Y9{IwKR<2I`%niwLhpcMTIneV2z~={fkWqs#!-JdbM5}HJpqo z@9b(w+t?tww5=*-SAm><(Og?uU?JNVy#Kc5OUS@V!@vrcO@{25Sn&kk(;hlQ_V0qw zTQZEz#~5PHWvpv$fuPVS@N&+c@pK@2%|7Q_Bwkwc8f1s;FDdN@Tg1w_; zrW$$bKWMxIb5dxVH|0GH?^wR=sF^B2YVx)9A6?K3!Nr!AmZQbLEp_6?7P1iUPz(0( z-)3G|PRfg#TdLN&2QT~x8`;uI=&-G5j}k9qqtNzR|XP=NsDFm6fAV&{eKN6+)E#F8op}Rb!k%P=t=gfH`Rwt7fu@9vkR< zXkdAEG1b~KgrEb)|9HB82Ni3-@+~Mry1se%FLIvVtpCXwMhg31owC_&81{E8Wh$+J z8+Izl(DC{54O#!4&3KftXVLie0E-ott$we#xl1lSMv-sbJiHVm+}ML8F~%T=ZM%v&z^Bc+i?-^+)GEx^p>Ce44)An8O6yqmn;9$Lrb+W>-C_X9sl}mIF5VAXoWt1m5s6%XH_)T{jGa z^;Uimw$$I4cN;FdV0jI{{M+)k85iWL($$mFvR%pIO+RNgW$g9--voPWG2ftGn`wj? zM9HW(Eg?sk_2yI{X%qV!0~>%?l<5bQkfwRuLnzS~)8%m5VgD?VQ>8q(OGR{1tFCIu z#5v%Bu}+1)kOyT)FPGqh>&rj}CwSW-btM>#q5o?Wt$2WH9{Pn!o}}6KI3T#Lshsla)PMu@SXe z4Fuh<-n4UzL}jPH+GK$CFl1TBCi^fc{1mpI8jlRhjZv}q^9*@l+m+aVBJ4(cHRlOi zh;=ZMA|f#e6NqShf*5JAUhfq9m5nL|C3Q|F8rq-W7Y1Hji`$i<;R$|Q6|V>@;|*O_Cb7Q)4(!W zC~A3#JAZ);fGi^PyH;6@HbCi>;ooEt1v(lDIfkg-xLE@cG7C8;SBuh=3a~Nxmv7zZ z8I6&t(wXYMq;HzVnzY_}u4``(nJOFUT7Fh7rpcrbb*wDd_Yh)nQ6fB0p=PDfz zum|_6i^?}`>Umi*ZQ!07j64otmUy$Ls&-A6slh|#_Jn(xo$juCgcT|UYWfW1mBc!A z-GA%`k#+QS!$ujtv8qoJhD~>ntyI+F1f*jD!Bu|af*Qx7s(4f zHJ|XnlWk>_3yr;SL|*DWqCcYwK@I#NaaR4Q6Ur@k?X(jRnjI^LD8z5m*JrSm(0T zT0!w8rA6#9JQQWlCMFj3NolK;n7~3t@0A+@-|CkoA(`9&sRsSv2RacrzMGhuDnOH- zUtzqNK25|lm35=P7gr=HshlG=$vWfoJ~jA-vJ}&2vw3>SbHRBe6N(@+blG7&{AFH= zOb-v8;dk(}5J{~0_pxO?EEnC^+omKY6{JtlWTG}Q-~0i=<9*SkEE^xn+D%oR;G4Q= zkeXWtbjICOJe=3Glb`@@Dl%(b-WQXlvLM77X53bPDRleYeI-mx`Q+Us*j^Rd63?(IH=CaU+744go28Hy!Wc^@ofE82E}kfpQmiV#6L^sy zgQ)$S4-Hz^QW0^&s(L0DI8%DP8s%T(9$93Au98P?NKDROJ`acZ=wUPTe;uC9Ir8>H z^^JYVvTsluHN9zj!Jn!83rN<2O9fNBKXN)Z`yMV};0#V*4RpJdqz2vb8+3GgByQ|D zwO92kL~c|6C2yEn;6Fk>HDHb#|7%L@a8lzBukBJ*jYct$ta7@asfbefUi z6pwb;bd9Ft30fpR<&^_ur&EM^Mt>W+us_3O$T$;`rnnMD2P=2{YNOP;7DQUVEkFH| zg$_EzoD;Y7i8q3j z^|Q3<{TLusha;);y;NCk@|$fvGPdb1<_H;O;tF6gl;78-iwe_aK2It_K#AWAqvm34 zl%)K#xI^_8N9M6KHaBo+koD2d&EY6JfM|C!97Kv%;C z;E*xJlOObGAKhCz5P44&P}%Lx{xqm4^zC=WwK+?YD`Qu!hYRwzOjBRqzhWMMT~{AB zEsuIjnlD^5o-_+sPiHcba6XRtuv@QIW!hG#!sNW1|BJ|aI;%CAToTxyO`Zfu{q-}N zf$6y>hx?OVjyEZ#O-Wdfd?c3-vx7ifmDsqM4V|%aXo}kSDm9%AnR)|8maub?b|boS zHlxg3TrUltcO*tNbwsD+`}OXmqr`;`Kee0k4|Wvna262y?FBeV?`R}b@)4YWIW1Zl zAy!WY2WN384xu)sQ}+-Aj-+Lse1ea+ye9~jWbd8IIkaHt6G}$DA`jA;n*K>JO0|>@dZx22@-=)yoXRR@OXk&!7<$;f$@<<45k=&3? zCA5stmiO!OR5a~;x)*ko|D#!u5B%)fjheK@-~aM9;dKUZ6Y}qcwC(GE%NB%I9Li); zhmp|n@C5Lr@{8ure*11Y+&LHHMa2jdxvG+VJ!%Sv=3fXXAU_y%c7P0HhMOrBd}brO+Bx=pU|v+_6xT~2pzZ)c)& z=J%tyJhBX%s= zMzrNH`q}WB)=1BbH#)TgHDbEgW42c)d;bv~*?6`Ds@uo!{JT+>!fq)9kEr9yhVEuG-)?B;MkoyNR4@f~EkIA`&gmF3 zW?np1yQ>YQ2N9N)JYmmg*1os$?v0hMPBauujFjO4w&<=H#sAOHgC+%(skPEXi99(M| z6Gm2k)Io2CkAy6XgMj>*8(PWf--sj7rO|S@zj0bg@8PrG4d3>exsmPAZJBb_A1EsN zqPm<+fKtTDjC^eZ)@T))$U=~ywK+b(||1LGb#7ePKaiXTh=QbK|AsE8Lk>h;MUJRY7cK}Gp}Iq${7 z+v}u?@(z!&5S})ZZ2mmn02KKxCKOW&I$chUNn{Z!#2zdxczh)W@gijBme~X%6m**L zSAe2G>^lkZ{pjEEhSNG3yD_z?3O|Rsv$Ck# z)9BTD<*sDkaBSBAL@Jwcu@@mumdbB8mAIN*W~*8BaiyE|S;P!Ci>p^V1wdvZ_!R+R z^0FBposzb^by=clb{9bZTa!RwM6-vt-WUx|`zR`M3u=Nd!YX6Ia+R`&|6s<^jOhC& z6Y0yM=N%y?8Y@~(DBOC+56o%CMVr)R$aur=MC^o-+CpbVZ#M8XY_{) z+oDtqIZwq^$lMGp%ZRaB|MC3Zpp}|UMQXK-n{sW+=j9G(eHOkK{U4BAc$%QX`?HU! zeB@vj&h@o+7&>i_ABm~;9E;~2o35oyq*U_)@XhD1@97#HAfuO--!Wdu()1$sXkFWv zv72V+84X;(jm>V}TW&!5$hcj&K4cZTF^XesgNoPjGsmPnoK!>Ci4C>YKK@;mfP$wKC^ZRG3fbw zUz@Qg&h&L~uf(;SMHLI?KHK%mtKMAKT0_RB+6N&+2`y!5=mnpRY~2{hS-tD=b5Y&& z4w#>XL7B`u;&Iib-Bexj8ENA3lZ58)9#!ZN0;{f;4)Ukmaf}B}uZA){@XG?*>YXQJ zJ9j6^vw_}iHekfg&!vP7*YfjsVT$tbf3`{$r&t3SPh=Dl8%0|E=7`0{cxmrW!;&g$+s2H947(vJQ>nS2mD z$VY{b%l^!Anr%;1*`@!h73J2$*Z#iAuOVGIp#$NJ2=yHNAw=AAD5yY%9-sW9-Z+=} z$(p!DK61=v2X=-Y@>X`1iM8vwctBz}T9I72=-^D8d)@4x?~jBcquQ2?`*+p}H%tUEHiA6ubIxqSnP|4wm~Wv7Riubvzg$ z1dneOH}Q%6YPaw~2+!4lqI78!pD!+6g@kjDH#h7nEsvhsAv}jQI_9t-RRRWJ#R0{j zD3Qzq_4LHvS$=)J9{TgVDE%@%gG+3y!wBy|(-g1a!GgWhd0d4fzaK6W4r4d_aRVb7>kK8DpuDwI&Fj{kBRWG$vfNi6g$(!?v(OTF2osX$frV zT+;u(a)9W%W2Rnw0QU|*Y(WlzpByIXWO=DvSmgbRi54? zYITqpPz1z?Fr>dt?^aGE?%zFEk|LxOo-vFjjPgABB?8&$q)nF)65&YW6Gt8lP^=Z0 zqrDdVq1XRftq$)EaVW3lo%itv7kJ`J_kVP;=nF8TCYL^ms|p0~URc-=`;ioMv%#}N z)Xs%?sI{f6RyvQSoai1abiIub(4XWhc%p_Vq~|IWm`6TmF%RNT4YI|fkrypNzvi3f z7DeK2$ON7EFDvtwl7)5zhu02W2n4`Fx;Oc1%Zc$>;6+J3r2IfID|rGvCu@5Ug+^DT z+(r;l1K{~Kq#>}iP)Z3Qg3PV8PM_6eIuFepdc3&*+xxNL)hN-6KDKkpE962$xbHKz z7H4aEmJTKNsD0C7&U|aMn1cZYx+GK;NG9K|=-*zKPUi!?6-+}fOJn4f>pQxITZrsN zh=B)-7n6097`uZ*ubZoXr;Hb24{h>5O#Jtp2{Tz&s3r6|pLVEISf7R>dy1xKA6HxX z3K`US8sV1$CMyf6I^4*D2im5<1NTo>zXiha_Yv{ek_xIbhnzr@zy~p`!$mD(w(yu^ zJ{kdY+uCxDSIp1y9)fLxCpSIT)DF07bGRYR; zaN@N?mF4{IqsOl|-;gKho+SAtGWu}b!xamzzP^3Sp%S8Y&~xA4r53va*Qlh;tj8U& zu7#3*AY8+8t>e1ZTZDQ#;hW@{VY=sKQs5!39dFeZf?Ev?%z;Hpgadc_eC`pKpN*O; zguTLCf46?41<+@eYNCI`^rq4#9`mx*kY?2X>1)ogRkA`=;t-(E+A^lT2xvru{5x=z zRL~(K>GDK4$Vjx9=Cl`!Q+>%+YlSc1 zeE_IFhk}a1wf&B)z_ss7?N*?|eF7jQIqzJFg$N19^z5aP$KQ)H)2z?ftEnbh#mB7c zyPv`Rw|>NyGNx?_XDPhr>uHVhJ*hBpcbfzN_PVtY>drDQh~P1yEj))YgBE^HlZ4ut zAAUY)$j;o3HYv!e%9W)8;?bcVqwu3u&C)-FSlTa`wS(Cti{tIdx_2)=!C+@lE^mNc zx9k@eS!E@V*kmCO-wTbtoE!A~8cx~~@L-Laj?8&HvB(L7?q%idOI*Z;EOzK66w(8^ zD){oHhdsvw{ z?D1|y`MelUN=j)no@;L7deL2>`U3CCwUB^kjXbA>p#C?-R&1V9^EivHQs7iyLV5X@ z&t-Rd?`0cKsEK}ImMeI~z5+g0Q1fL*r+I&T!O832P}7p(SY6=;4#Q(w4{=~rUphWw z=2De+7<^pAtS7-NXy(_{?G5EQEPFw&i2ufnGu#YG??EPE+UlycH1E2;Rg@A+MuT!k z&VO}wZPrLU?|WEfOE(te{q1L{65x%%kp8u$jOM5y@9;YvIN@rh(q(YK1xbzBLB{z% z^?RK*XX%(s48b5>(A$fF&n>ToqXQn3Viv$`yR#w}jF13_bb(qyxsRYGMmct8N}Vxm zjUFyVUt^aWve(BQqn_RCzuvzi&->3bUykE%0?2W6vGlF+O2R22q#(YrVD%9S$O+yu zrP4XWi@E4Zy0U3^6o(w36P=Gs14KbSM&6k%Y^ zV3I|zr^@+ueIU`=U_`D;;m4w;Ok;hRlVpJkp`a3; z51z=fi>CUgpSJ0Lw?#GAVNA24nklfKO~c~*$@>4Ecx`=6_L>?LkxLfd*+c9Ma#fW{ zo(Tp?0OT>N~iqy%X2UbAx~P@32?@^ER8ny z14Kd@-e&~MdoA2AGQ;wilXDt@wpTyr)~?kyf~B6c;r5S-!S-M@^4%)IPVIf3=~Mo# zU;O&h2b18{>7vsfUPjlRzjJ?6Ryi_}6Bcbue>fo7&GxuF&IEt0(_aXnpW&&9 zbhtgGn)R62nrQJg92CcHyGM_tS*J(>ssnVrPHXkErv@v}?5e*seaurmZq&`p>-#z9 zer)I?>YwqxmL{NjFWYZHC3SEgF#6!cX42l+=|NfTxav&Gd4urvZd4-{L~v5|cau%* zQ=;nbayr9;gU~EG-*`ZZ75NKtIj6`cq%u58LuoV>TW@n7{`odkTHzo`Qz7K?E>tYY zNETVPMYXsvWNttEXVN1yCp-T~j$NUZw1D13p(GL)S?urI`h7iX#dyepoKqIkgazjw z5Y+k)H#Q*Sr#!^wleo{Go2TzpMp*7HzOT}yAo$(Z_O&~)TM-QR5^z@4hqMo^BURrTI=UkUb{m9v&1~;QRR} z(<5`L*;KgStLxO+Wr4naci*3_^^7oMPsL(E-!)Z;vADHawgsM#GR4f9Szo^zsRiwj z=XG|lR`%08^pWtU@j3K+E1`CD{&~Q`y}k1WASiTX8soNmE0u2TxDb-|*Eb4HAI^^% z==)&-rC*|AuOE~RUaQ_&^XT^)x6xD%&VmNN&(ZG?zOwtXR{6B6>6jwOZatl2kz+7( zP@4+=GN!R`5VW~DJYNPyOqaj0eeGp=Yx_C;@y1Y_m10bF*#B~L3-P$Dnfk|#lK$iC zpHmB1rG4@WZjQ>AePl`ZXb1D90t{-*QvlBm2p!)-^TdlY%Qu1G*4@*tP zQ`5~1h${(D-inWSqMzi|ZY)IL?j4;ub|Pq!@2I{r@%sor>*lHWQbuRx;Qyu}PB&n$ z!?T9aPSw`^_d6S|@46bfS^9sH9h?CFhx|b~?q3DIa{S(^dRydh5BpJ)SAA9f(lq3M E09Ph37ytkO literal 0 HcmV?d00001 From 39dd340a1a4f24c22106fde3c6da5d89bb59a91f Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 11 Feb 2025 16:10:30 -0800 Subject: [PATCH 069/115] Change TORCH_LIBRARY to TORCH_LIBRARY_FRAGMENT (#1645) * change TORCH_LIBRARY to TORCH_LIBRARY_FRAGMENT to prevent conflict between cpu/mps * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up * up --- .../op_linear_8bit_act_xbit_weight_aten.cpp | 2 +- torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp index 24d4008969..0307f05192 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp @@ -60,7 +60,7 @@ "_linear_8bit_act_" #weight_nbit "bit_weight", \ &linear_meta); -TORCH_LIBRARY(torchao, m) { +TORCH_LIBRARY_FRAGMENT(torchao, m) { DEFINE_OP(1); DEFINE_OP(2); DEFINE_OP(3); diff --git a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm index 162b5ab83c..2aeb7f4460 100644 --- a/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm +++ b/torchao/experimental/ops/mps/linear_fp_act_xbit_weight_aten.mm @@ -163,7 +163,7 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) { return B; } -TORCH_LIBRARY(torchao, m) { +TORCH_LIBRARY_FRAGMENT(torchao, m) { m.def("_pack_weight_1bit(Tensor W) -> Tensor"); m.def("_pack_weight_2bit(Tensor W) -> Tensor"); m.def("_pack_weight_3bit(Tensor W) -> Tensor"); From 682ffd5f3e5d0da636b9e12684c426bbd1eac2e0 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 11 Feb 2025 16:57:25 -0800 Subject: [PATCH 070/115] Update to cutlass 3.8 (#1634) --- third_party/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cutlass b/third_party/cutlass index b78588d163..e9627ce55b 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit b78588d1630aa6643bf021613717bafb705df4ef +Subproject commit e9627ce55b42fd2599f58cd4396da9380954def0 From aa514863fe0c4a778b35dcdb116804e7f0a79ad2 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Tue, 11 Feb 2025 17:35:11 -0800 Subject: [PATCH 071/115] SAM2: Collect p90 latency statistics (#1703) --- examples/sam2_amg_server/generate_data.py | 3 + examples/sam2_amg_server/result.csv | 140 +++++++++++----------- 2 files changed, 73 insertions(+), 70 deletions(-) diff --git a/examples/sam2_amg_server/generate_data.py b/examples/sam2_amg_server/generate_data.py index 8632f0163a..311a3825ec 100644 --- a/examples/sam2_amg_server/generate_data.py +++ b/examples/sam2_amg_server/generate_data.py @@ -60,6 +60,8 @@ def latencies_statistics(data): mean = np.mean(data_array) # Calculate the median median = np.median(data_array) + # Calculate the 90th percentile + p90 = np.percentile(data_array, 90) # Calculate the 95th percentile p95 = np.percentile(data_array, 95) # Calculate the 99th percentile @@ -74,6 +76,7 @@ def latencies_statistics(data): { "mean": mean, "median": median, + "p90": p90, "p95": p95, "p99": p99, "p999": p999, diff --git a/examples/sam2_amg_server/result.csv b/examples/sam2_amg_server/result.csv index 0327159727..86196ac981 100644 --- a/examples/sam2_amg_server/result.csv +++ b/examples/sam2_amg_server/result.csv @@ -1,70 +1,70 @@ -furious,fast,points-per-batch,bytes,argmax,p95,p999,p99,miou,fourth,total_time,torch_version,total_img_s,batch-size,second,experiment_name,run_script_time,mean,batch_size,percentage,third,task,num-images,fifth,environ,fail_count,allow-recompiles,max,load-exported-model,torchvision_version,median,total_ms_per_img,gpu-preproc,meta-folder,bytes_MiB,first,baseline,export-model -,,64,4561654784,468,1323ms,2363ms,2086ms,,892ms,927.4758312702179s,2.7.0.dev20250201+cu124,1.0781952114379705img/s,,1046ms,baseline_amg,931.3759133815765,921ms,1,4,955ms,amg,,724ms,None,,,2466ms,,0.22.0.dev20250201+cu124,869ms,927.4758312702179ms,,,4350,1733ms,None, -,,64,4205527040,0,815ms,904ms,857ms,1.0,660ms,718.6690595149994s,2.7.0.dev20250201+cu124,1.3914610442181266img/s,,748ms,amg_ao,723.3117945194244,713ms,1,4,673ms,amg,,760ms,None,0.0,,1263ms,,0.22.0.dev20250201+cu124,697ms,718.6690595149994ms,,,4010,1263ms,, -,,1024,35427762688,109,745ms,1006ms,791ms,0.9999994533658028,577ms,631.6344785690308s,2.7.0.dev20250201+cu124,1.5831941319376708img/s,1,619ms,amg_ao_ppb_1024_basic,635.8103907108307,626ms,1,34,594ms,amg,,609ms,None,0.0,,1947ms,,0.22.0.dev20250201+cu124,610ms,631.6344785690308ms,,,33786,1005ms,, -,None,1024,30775568896,0,576ms,3526ms,644ms,,501ms,849.2408077716827s,2.7.0.dev20250201+cu124,1.1775223126923131img/s,1,3157ms,amg_ao_ppb_1024_fast_cold,861.5647690296173,841ms,1,30,421ms,amg,,501ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_inductor_cache_dir'},,,372124ms,,0.22.0.dev20250201+cu124,466ms,849.2408077716827ms,,,29349,372124ms,, -,None,1024,30775568896,0,541ms,1512ms,617ms,0.9937346105006776,386ms,452.082448720932s,2.7.0.dev20250201+cu124,2.2119858951155487img/s,1,1000ms,amg_ao_ppb_1024_fast,458.1768579483032,446ms,1,30,448ms,amg,,392ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_inductor_cache_dir'},191.0,,8411ms,,0.22.0.dev20250201+cu124,422ms,452.082448720932ms,,,29349,8411ms,, -,,1024,18221665280,,,,,,,356.0369083881378s,2.7.0.dev20250201+cu124,0.0img/s,1,,amg_ao_ppb_1024_save_export,367.34787678718567,,1,17,,amg,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,,17377,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast -,,1024,49836364288,837,559ms,1592ms,639ms,0.993709121615135,397ms,460.2203013896942s,2.7.0.dev20250201+cu124,2.1728724199701137img/s,1,493ms,amg_ao_ppb_1024_load_export_cold,464.4886541366577,453ms,1,48,443ms,amg,,510ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_inductor_cache_dir'},188.0,,1760ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,436ms,460.2203013896942ms,,,47527,961ms,, -,,1024,49836364288,837,592ms,1691ms,649ms,0.993709121615135,445ms,478.4169816970825s,2.7.0.dev20250201+cu124,2.09022680685939img/s,1,431ms,amg_ao_ppb_1024_load_export,483.0541400909424,472ms,1,48,429ms,amg,,508ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_inductor_cache_dir'},188.0,,1737ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,462ms,478.4169816970825ms,,,47527,763ms,, -,,1024,49861530112,837,565ms,1670ms,622ms,0.9937652501226203,398ms,465.69065976142883s,2.7.0.dev20250201+cu124,2.1473482000096276img/s,1,435ms,amg_ao_ppb_1024_load_export_gpu_preproc,469.74300265312195,460ms,1,48,427ms,amg,,397ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_inductor_cache_dir'},185.0,,1735ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,452ms,465.69065976142883ms,None,,47551,776ms,, -,None,1024,49836364288,837,546ms,1611ms,608ms,0.993709121615135,415ms,454.15750002861023s,2.7.0.dev20250201+cu124,2.201879303847242img/s,1,438ms,amg_ao_ppb_1024_fast_export_cold,458.17887783050537,448ms,1,48,545ms,amg,,421ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_inductor_cache_dir'},188.0,,1730ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,430ms,454.15750002861023ms,,,47527,943ms,, -,None,1024,49836364288,837,577ms,1702ms,643ms,0.993709121615135,402ms,473.2662968635559s,2.7.0.dev20250201+cu124,2.112975309307316img/s,1,432ms,amg_ao_ppb_1024_fast_export,477.25709891319275,467ms,1,48,427ms,amg,,486ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_inductor_cache_dir'},188.0,,1742ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,451ms,473.2662968635559ms,,,47527,754ms,, -,None,1024,49861530112,837,543ms,1597ms,596ms,0.9937652501226203,396ms,450.6334979534149s,2.7.0.dev20250201+cu124,2.219098235132482img/s,1,433ms,amg_ao_ppb_1024_fast_export_gpu_preproc,454.61152243614197,445ms,1,48,426ms,amg,,395ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_inductor_cache_dir'},185.0,,1766ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast,0.22.0.dev20250201+cu124,430ms,450.6334979534149ms,None,,47551,764ms,, -None,None,1024,29712131072,0,275ms,2880ms,333ms,0.9736336072679046,169ms,994.9303135871887s,2.7.0.dev20250201+cu124,1.0050955190967423img/s,1,2081ms,amg_ao_ppb_1024_fast_furious_cold,1006.4958641529083,987ms,1,29,192ms,amg,,143ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_furious_inductor_cache_dir'},305.0,,800771ms,,0.22.0.dev20250201+cu124,174ms,994.9303135871887ms,,,28335,800771ms,, -None,None,1024,29712131072,0,274ms,933ms,334ms,0.9736336072679046,163ms,192.62348794937134s,2.7.0.dev20250201+cu124,5.191474885258216img/s,1,699ms,amg_ao_ppb_1024_fast_furious,198.63731622695923,186ms,1,29,179ms,amg,,130ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_furious_inductor_cache_dir'},305.0,,10094ms,,0.22.0.dev20250201+cu124,165ms,192.62348794937134ms,,,28335,10094ms,, -None,,1024,9179703808,,,,,,,519.6249597072601s,2.7.0.dev20250201+cu124,0.0img/s,1,,amg_ao_ppb_1024_save_export_furious,529.3503592014313,,1,8,,amg,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_furious_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,,8754,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious -None,,1024,29307644416,468,259ms,906ms,309ms,0.971583874842335,166ms,178.88770842552185s,2.7.0.dev20250201+cu124,5.590099000101732img/s,1,202ms,amg_ao_ppb_1024_load_export_furious_cold,183.20707321166992,169ms,1,28,198ms,amg,,169ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_furious_inductor_cache_dir'},308.0,,1468ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,158ms,178.88770842552185ms,,,27949,906ms,, -None,,1024,29307644416,468,258ms,716ms,299ms,0.971583874842335,167ms,173.60630631446838s,2.7.0.dev20250201+cu124,5.760159416033033img/s,1,164ms,amg_ao_ppb_1024_load_export_furious,177.37090826034546,168ms,1,28,156ms,amg,,125ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_furious_inductor_cache_dir'},308.0,,1468ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,157ms,173.60630631446838ms,,,27949,716ms,, -None,,1024,29308632576,468,232ms,679ms,282ms,0.9707489542138409,126ms,156.5510959625244s,2.7.0.dev20250201+cu124,6.387690829321198img/s,1,160ms,amg_ao_ppb_1024_load_export_furious_gpu_preproc,160.46401953697205,151ms,1,28,155ms,amg,,126ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_load_export_furious_inductor_cache_dir'},290.0,,1467ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,136ms,156.5510959625244ms,None,,27950,678ms,, -None,None,1024,29307644416,468,268ms,750ms,320ms,0.971583874842335,159ms,182.61804270744324s,2.7.0.dev20250201+cu124,5.4759101848551435img/s,1,162ms,amg_ao_ppb_1024_fast_export_furious_cold,187.25734424591064,177ms,1,28,158ms,amg,,149ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},308.0,,1466ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,165ms,182.61804270744324ms,,,27949,750ms,, -None,None,1024,29307644416,468,259ms,700ms,308ms,0.971583874842335,134ms,178.3385353088379s,2.7.0.dev20250201+cu124,5.607313070437913img/s,1,160ms,amg_ao_ppb_1024_fast_export_furious,182.3735547065735,173ms,1,28,157ms,amg,,162ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},308.0,,1507ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,163ms,178.3385353088379ms,,,27949,700ms,, -None,None,1024,16525926912,0,201ms,36421ms,227ms,0.9716291864482343,141ms,245.76354837417603s,2.7.0.dev20250201+cu124,4.068951667630937img/s,1,137ms,amg_ao_ppb_1024_fast_export_furious_recompiles,251.90375113487244,240ms,1,16,131ms,amg,,128ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},311.0,None,49208ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,140ms,245.76354837417603ms,,,15760,49208ms,, -None,None,1024,29308632576,468,233ms,774ms,283ms,0.9707489542138409,127ms,157.9279761314392s,2.7.0.dev20250201+cu124,6.3320003491194425img/s,1,163ms,amg_ao_ppb_1024_fast_export_furious_gpu_preproc,162.7095422744751,152ms,1,28,157ms,amg,,129ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},290.0,,1464ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,137ms,157.9279761314392ms,None,,27950,773ms,, -None,None,1024,16551092736,0,174ms,308ms,203ms,0.9708677416053486,115ms,137.26364755630493s,2.7.0.dev20250201+cu124,7.28525008480344img/s,1,135ms,amg_ao_ppb_1024_fast_export_furious_gpu_preproc_recompiles,142.44125938415527,130ms,1,16,135ms,amg,,116ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_fast_export_furious_inductor_cache_dir'},293.0,None,2189ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/amg_ao_fast_furious,0.22.0.dev20250201+cu124,121ms,137.26364755630493ms,None,,15784,2189ms,, -,,1,1402492416,0,214ms,316ms,281ms,,100ms,136.17227387428284s,2.7.0.dev20250201+cu124,7.343638844741783img/s,,118ms,baseline_sps,140.2417643070221,131ms,1,1,105ms,sps,,227ms,None,,,532ms,,0.22.0.dev20250201+cu124,115ms,136.17227387428284ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1337,532ms,None, -,,1,1404942848,0,205ms,229ms,219ms,1.0,105ms,127.24607348442078s,2.7.0.dev20250201+cu124,7.858788665274091img/s,,105ms,sps_ao,131.5206482410431,122ms,1,1,102ms,sps,,225ms,None,0.0,,579ms,,0.22.0.dev20250201+cu124,110ms,127.24607348442076ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1339,579ms,, -,,1,1404989952,0,203ms,256ms,218ms,1.0,106ms,124.8940806388855s,2.7.0.dev20250201+cu124,8.006784588065194img/s,1,104ms,sps_ao_ppb_1_basic,128.7957148551941,120ms,1,1,102ms,sps,,217ms,None,0.0,,583ms,,0.22.0.dev20250201+cu124,109ms,124.8940806388855ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1339,583ms,, -,None,1,1408784896,0,216ms,3260ms,223ms,,201ms,488.7042841911316s,2.7.0.dev20250201+cu124,2.046227201906217img/s,1,2959ms,sps_ao_ppb_1_fast_cold,496.82423877716064,483ms,1,1,212ms,sps,,209ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_inductor_cache_dir'},,,304090ms,,0.22.0.dev20250201+cu124,203ms,488.7042841911316ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1343,304090ms,, -,None,1,1366200320,0,217ms,775ms,222ms,0.9998691322207451,122ms,196.3028929233551s,2.7.0.dev20250201+cu124,5.0941684307752img/s,1,768ms,sps_ao_ppb_1_fast,202.54180693626404,189ms,1,1,195ms,sps,,208ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_inductor_cache_dir'},0.0,,8209ms,,0.22.0.dev20250201+cu124,205ms,196.3028929233551ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1302,8209ms,, -,,1,1390578176,,,,,,,307.4514627456665s,2.7.0.dev20250201+cu124,0.0img/s,1,,sps_ao_ppb_1_save_export,316.7780604362488,,1,1,,sps,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1326,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast -,,1,6238665728,0,215ms,233ms,221ms,0.9998687437176704,202ms,160.5826907157898s,2.7.0.dev20250201+cu124,6.227321235822784img/s,1,221ms,sps_ao_ppb_1_load_export_cold,165.16510462760925,153ms,1,6,198ms,sps,,214ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_inductor_cache_dir'},0.0,,576ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,138ms,160.5826907157898ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5949,576ms,, -,,1,6238665728,0,213ms,294ms,220ms,0.9998687437176704,210ms,130.84592247009277s,2.7.0.dev20250201+cu124,7.642576712534304img/s,1,108ms,sps_ao_ppb_1_load_export,135.52789616584778,125ms,1,6,144ms,sps,,140ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_inductor_cache_dir'},0.0,,434ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,104ms,130.84592247009277ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5949,434ms,, -,,1,6261886976,0,165ms,180ms,175ms,0.999868236720562,100ms,118.1360731124878s,2.7.0.dev20250201+cu124,8.46481496847971img/s,1,103ms,sps_ao_ppb_1_load_export_gpu_preproc,122.45444965362549,112ms,1,6,103ms,sps,,98ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_inductor_cache_dir'},0.0,,488ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,103ms,118.1360731124878ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5971,488ms,, -,None,1,6238665728,0,206ms,226ms,216ms,0.9998687437176704,92ms,124.29203748703003s,2.7.0.dev20250201+cu124,8.045567682518286img/s,1,121ms,sps_ao_ppb_1_fast_export_cold,128.70573449134827,118ms,1,6,135ms,sps,,96ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_inductor_cache_dir'},0.0,,430ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,104ms,124.29203748703003ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5949,430ms,, -,None,1,6238665728,0,200ms,226ms,216ms,0.9998687437176704,99ms,121.70427465438843s,2.7.0.dev20250201+cu124,8.216638263855277img/s,1,99ms,sps_ao_ppb_1_fast_export,126.40637016296387,115ms,1,6,96ms,sps,,105ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_inductor_cache_dir'},0.0,,474ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,103ms,121.70427465438843ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5949,474ms,, -,None,1,6261886976,0,168ms,189ms,178ms,0.999868236720562,93ms,122.82635688781738s,2.7.0.dev20250201+cu124,8.141575027852884img/s,1,107ms,sps_ao_ppb_1_fast_export_gpu_preproc,127.55544590950012,117ms,1,6,98ms,sps,,172ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_inductor_cache_dir'},0.0,,481ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast,0.22.0.dev20250201+cu124,104ms,122.82635688781738ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,5971,481ms,, -None,None,1,903450624,0,66ms,2448ms,71ms,0.9996802344322204,18ms,598.2366213798523s,2.7.0.dev20250201+cu124,1.6715793788977134img/s,1,1896ms,sps_ao_ppb_1_fast_furious_cold,606.6854190826416,590ms,1,0,24ms,sps,,30ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_furious_inductor_cache_dir'},0.0,,553957ms,,0.22.0.dev20250201+cu124,30ms,598.2366213798523ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,861,553957ms,, -None,None,1,903450624,0,60ms,922ms,68ms,0.9996802344322204,19ms,46.42959976196289s,2.7.0.dev20250201+cu124,21.537984499690705img/s,1,914ms,sps_ao_ppb_1_fast_furious,52.85066604614258,40ms,1,0,27ms,sps,,52ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_furious_inductor_cache_dir'},0.0,,8831ms,,0.22.0.dev20250201+cu124,28ms,46.42959976196289ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,861,8831ms,, -None,,1,903450624,,,,,,,395.61680269241333s,2.7.0.dev20250201+cu124,0.0img/s,1,,sps_ao_ppb_1_save_export_furious,405.58058881759644,,1,0,,sps,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_furious_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,861,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious -None,,1,1768025088,0,63ms,78ms,70ms,0.9996752961277962,31ms,40.04996109008789s,2.7.0.dev20250201+cu124,24.968813271768536img/s,1,41ms,sps_ao_ppb_1_load_export_furious_cold,44.494996547698975,33ms,1,1,54ms,sps,,58ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_furious_inductor_cache_dir'},0.0,,688ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,29ms,40.04996109008789ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,688ms,, -None,,1,1768025088,0,67ms,98ms,73ms,0.9996752961277962,54ms,41.31868815422058s,2.7.0.dev20250201+cu124,24.20212365570597img/s,1,24ms,sps_ao_ppb_1_load_export_furious,45.522459983825684,36ms,1,1,24ms,sps,,24ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_furious_inductor_cache_dir'},0.0,,769ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,31ms,41.31868815422058ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,769ms,, -None,,1,1794153472,0,28ms,33ms,30ms,0.9996936089992523,18ms,30.337790489196777s,2.7.0.dev20250201+cu124,32.96218952913192img/s,1,21ms,sps_ao_ppb_1_load_export_furious_gpu_preproc,35.1632604598999,22ms,1,1,22ms,sps,,22ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_load_export_furious_inductor_cache_dir'},0.0,,720ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,20ms,30.337790489196777ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1711,720ms,, -None,None,1,1768025088,0,59ms,82ms,69ms,0.9996752961277962,37ms,36.78891086578369s,2.7.0.dev20250201+cu124,27.182103967368906img/s,1,39ms,sps_ao_ppb_1_fast_export_furious_cold,40.70477890968323,31ms,1,1,53ms,sps,,35ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,,752ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,28ms,36.78891086578369ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,752ms,, -None,None,1,1768025088,0,62ms,74ms,69ms,0.9996752961277962,45ms,37.20629072189331s,2.7.0.dev20250201+cu124,26.877175353886315img/s,1,39ms,sps_ao_ppb_1_fast_export_furious,41.312560081481934,32ms,1,1,22ms,sps,,23ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,,678ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,29ms,37.20629072189331ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,678ms,, -None,None,1,1768025088,0,58ms,82ms,68ms,0.24502152660781712,19ms,44.12568783760071s,2.7.0.dev20250201+cu124,22.662536246015694img/s,1,62ms,sps_ao_ppb_1_fast_export_furious_recompiles,49.61470317840576,38ms,1,1,22ms,sps,,23ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,None,8124ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,28ms,44.12568783760071ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1686,8124ms,, -None,None,1,1794153472,0,26ms,29ms,27ms,0.9996936089992523,16ms,25.35749101638794s,2.7.0.dev20250201+cu124,39.436078252131644img/s,1,20ms,sps_ao_ppb_1_fast_export_furious_gpu_preproc,29.401476621627808,20ms,1,1,20ms,sps,,21ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,,662ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,19ms,25.35749101638794ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1711,662ms,, -None,None,1,1794153472,0,26ms,31ms,27ms,0.22546337781244644,17ms,26.919757604599s,2.7.0.dev20250201+cu124,37.14743701218019img/s,1,21ms,sps_ao_ppb_1_fast_export_furious_gpu_preproc_recompiles,32.35977077484131,22ms,1,1,20ms,sps,,21ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/sps_fast_export_furious_inductor_cache_dir'},0.0,None,2134ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/sps_ao_fast_furious,0.22.0.dev20250201+cu124,19ms,26.919757604599ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1711,2134ms,, -,,,1402492416,126,775ms,1593ms,1171ms,,150ms,331.5782699584961s,2.7.0.dev20250201+cu124,3.0158791772608344img/s,,289ms,baseline_mps,335.87450075149536,324ms,1,1,304ms,mps,,541ms,None,,,1991ms,,0.22.0.dev20250201+cu124,258ms,331.5782699584961ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1337,611ms,None, -,,,8411175424,0,227ms,311ms,239ms,0.999999164044857,105ms,143.97097539901733s,2.7.0.dev20250201+cu124,6.945844446969173img/s,,127ms,mps_ao,148.60355854034424,137ms,1,8,117ms,mps,,127ms,None,0.0,,634ms,,0.22.0.dev20250201+cu124,122ms,143.97097539901733ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,8021,634ms,, -,,,8411175424,0,234ms,309ms,259ms,0.999999164044857,221ms,164.95788407325745s,2.7.0.dev20250201+cu124,6.062153413388245img/s,1,234ms,mps_ao_ppb_None_basic,168.8498158454895,158ms,1,8,231ms,mps,,242ms,None,0.0,,644ms,,0.22.0.dev20250201+cu124,135ms,164.95788407325745ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,8021,644ms,, -,None,,8411176448,0,220ms,54779ms,243ms,,209ms,568.1692686080933s,2.7.0.dev20250201+cu124,1.7600388744181994img/s,1,1564ms,mps_ao_ppb_None_fast_cold,577.6140518188477,561ms,1,8,130ms,mps,,214ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_inductor_cache_dir'},,,332350ms,,0.22.0.dev20250201+cu124,115ms,568.1692686080933ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,8021,332350ms,, -,None,,8411176448,0,221ms,1345ms,240ms,0.9983834705352783,97ms,165.37928342819214s,2.7.0.dev20250201+cu124,6.0467065721336315img/s,1,580ms,mps_ao_ppb_None_fast,170.9393391609192,155ms,1,8,109ms,mps,,144ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_inductor_cache_dir'},0.0,,9522ms,,0.22.0.dev20250201+cu124,126ms,165.37928342819214ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,8021,9522ms,, -,,,1390578176,,,,,,,206.4340798854828s,2.7.0.dev20250201+cu124,0.0img/s,1,,mps_ao_ppb_None_save_export,217.42104578018188,,1,1,,mps,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,1326,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast -,,,7556661248,0,218ms,322ms,236ms,0.998383426964283,104ms,138.59291863441467s,2.7.0.dev20250201+cu124,7.215375863739731img/s,1,116ms,mps_ao_ppb_None_load_export_cold,143.01005744934082,131ms,1,7,112ms,mps,,122ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_inductor_cache_dir'},0.0,,579ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,115ms,138.59291863441467ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7206,579ms,, -,,,7556661248,0,218ms,258ms,237ms,0.998383426964283,97ms,136.831298828125s,2.7.0.dev20250201+cu124,7.308269442476818img/s,1,116ms,mps_ao_ppb_None_load_export,141.67460775375366,129ms,1,7,111ms,mps,,120ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_inductor_cache_dir'},0.0,,589ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,114ms,136.831298828125ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7206,589ms,, -,,,7581827072,0,190ms,374ms,216ms,0.9984678273200989,170ms,149.05044078826904s,2.7.0.dev20250201+cu124,6.70913815961492img/s,1,187ms,mps_ao_ppb_None_load_export_gpu_preproc,153.32005190849304,142ms,1,7,181ms,mps,,143ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_inductor_cache_dir'},0.0,,596ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,135ms,149.05044078826904ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7230,596ms,, -,None,,7556661248,0,208ms,54466ms,226ms,0.9983833708167076,188ms,287.1738612651825s,2.7.0.dev20250201+cu124,3.482211074484173img/s,1,131ms,mps_ao_ppb_None_fast_export_cold,295.3504989147186,278ms,1,7,108ms,mps,,140ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_inductor_cache_dir'},0.0,,62539ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,109ms,287.1738612651825ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7206,62539ms,, -,None,,7556661248,0,218ms,1720ms,230ms,0.9983833900690079,195ms,141.05165219306946s,2.7.0.dev20250201+cu124,7.089601464796843img/s,1,230ms,mps_ao_ppb_None_fast_export,147.43897795677185,133ms,1,7,216ms,mps,,222ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_inductor_cache_dir'},0.0,,3561ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,111ms,141.05165219306946ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7206,3561ms,, -,None,,7581827072,0,185ms,1572ms,197ms,0.9984678581357003,94ms,148.53872227668762s,2.7.0.dev20250201+cu124,6.73225125861302img/s,1,107ms,mps_ao_ppb_None_fast_export_gpu_preproc,154.97156023979187,141ms,1,7,105ms,mps,,112ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_inductor_cache_dir'},0.0,,4246ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast,0.22.0.dev20250201+cu124,127ms,148.53872227668762ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,7230,4246ms,, -None,None,,4427842560,0,74ms,63302ms,84ms,0.9964296479523181,22ms,723.8993864059448s,2.7.0.dev20250201+cu124,1.3814074424967462img/s,1,1071ms,mps_ao_ppb_None_fast_furious_cold,733.4108500480652,716ms,1,4,29ms,mps,,37ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_furious_inductor_cache_dir'},0.0,,581345ms,,0.22.0.dev20250201+cu124,49ms,723.8993864059448ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,4222,581345ms,, -None,None,,4427842560,0,74ms,1300ms,85ms,0.9964293534457683,20ms,58.8767945766449s,2.7.0.dev20250201+cu124,16.9846202937936img/s,1,350ms,mps_ao_ppb_None_fast_furious,64.73449230194092,51ms,1,4,29ms,mps,,30ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_furious_inductor_cache_dir'},0.0,,8402ms,,0.22.0.dev20250201+cu124,34ms,58.8767945766449ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,4222,8402ms,, -None,,,903450624,,,,,,,315.72570967674255s,2.7.0.dev20250201+cu124,0.0img/s,1,,mps_ao_ppb_None_save_export_furious,324.74191069602966,,1,0,,mps,0,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_furious_inductor_cache_dir'},,,,,0.22.0.dev20250201+cu124,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,861,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious -None,,,3998911488,0,82ms,301ms,90ms,0.9955771351754665,41ms,57.82986092567444s,2.7.0.dev20250201+cu124,17.292104528579888img/s,1,38ms,mps_ao_ppb_None_load_export_furious_cold,62.62674617767334,51ms,1,3,37ms,mps,,40ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_furious_inductor_cache_dir'},0.0,,754ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,46ms,57.82986092567444ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,754ms,, -None,,,3998911488,0,88ms,252ms,97ms,0.9955771351754665,32ms,65.55874681472778s,2.7.0.dev20250201+cu124,15.25349474458456img/s,1,80ms,mps_ao_ppb_None_load_export_furious,70.35485363006592,58ms,1,3,39ms,mps,,40ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_furious_inductor_cache_dir'},0.0,,875ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,53ms,65.55874681472778ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,875ms,, -None,,,4024077312,0,45ms,285ms,56ms,0.9959434471726417,29ms,41.67199182510376s,2.7.0.dev20250201+cu124,23.996933100701625img/s,1,35ms,mps_ao_ppb_None_load_export_furious_gpu_preproc,46.09472918510437,35ms,1,3,35ms,mps,,36ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_load_export_furious_inductor_cache_dir'},0.0,,653ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,32ms,41.67199182510376ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3837,653ms,, -None,None,,3998911488,0,68ms,51237ms,77ms,0.9966195167303086,20ms,211.8625111579895s,2.7.0.dev20250201+cu124,4.720042231795708img/s,1,27ms,mps_ao_ppb_None_fast_export_furious_cold,218.6763949394226,204ms,1,3,30ms,mps,,66ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,,79408ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,32ms,211.8625111579895ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,79408ms,, -None,None,,3998911488,0,70ms,1746ms,78ms,0.9966195802688599,59ms,51.70280361175537s,2.7.0.dev20250201+cu124,19.341310918246524img/s,1,43ms,mps_ao_ppb_None_fast_export_furious,57.28682208061218,44ms,1,3,34ms,mps,,70ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,,3842ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,35ms,51.70280361175537ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,3842ms,, -None,None,,3998911488,0,65ms,6664ms,75ms,0.9956195802688599,20ms,59.52086091041565s,2.7.0.dev20250201+cu124,16.8008322578716img/s,1,56ms,mps_ao_ppb_None_fast_export_furious_recompiles,64.74269723892212,52ms,1,3,27ms,mps,,29ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,None,11728ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,30ms,59.52086091041565ms,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3813,11728ms,, -None,None,,4024077312,0,37ms,1743ms,46ms,0.9960403459072114,19ms,37.689289808273315s,2.7.0.dev20250201+cu124,26.5327366232432img/s,1,26ms,mps_ao_ppb_None_fast_export_furious_gpu_preproc,42.8827166557312,31ms,1,3,27ms,mps,,30ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,,3914ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,23ms,37.689289808273315ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3837,3914ms,, -None,None,,4024077312,0,35ms,1672ms,43ms,0.9950685520768165,22ms,44.08118724822998s,2.7.0.dev20250201+cu124,22.685414400678457img/s,1,26ms,mps_ao_ppb_None_fast_export_furious_gpu_preproc_recompiles,50.419389486312866,36ms,1,3,26ms,mps,,31ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/mps_fast_export_furious_inductor_cache_dir'},0.0,None,9520ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/exported_models/mps_ao_fast_furious,0.22.0.dev20250201+cu124,23ms,44.08118724822998ms,None,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_20/amg_baseline_annotations,3837,9520ms,, +torchvision_version,load-exported-model,p99,environ,miou,second,experiment_name,max,argmax,fast,gpu-preproc,allow-recompiles,mean,total_img_s,fail_count,furious,total_time,baseline,first,third,fifth,task,median,p999,meta-folder,export-model,percentage,batch_size,p90,points-per-batch,total_ms_per_img,p95,run_script_time,torch_version,fourth,num-images,bytes_MiB,batch-size,bytes +0.22.0.dev20250201+cu124,,2080ms,None,,991ms,baseline_amg,2489ms,222,,,,918ms,1.0819226362225578img/s,,,924.2805044651031s,None,1786ms,1050ms,865ms,amg,864ms,2313ms,,,4,1,1144ms,64,924.2805044651031ms,1310ms,928.9303262233734,2.7.0.dev20250201+cu124,993ms,,4350,,4561654784 +0.22.0.dev20250201+cu124,,852ms,None,1.0,790ms,amg_ao,1290ms,0,,,,709ms,1.3988966237833114img/s,0.0,,714.8491053581238s,,1290ms,783ms,766ms,amg,693ms,919ms,,,4,1,786ms,64,714.8491053581238ms,807ms,719.3047206401825,2.7.0.dev20250201+cu124,772ms,,4010,,4205527040 +0.22.0.dev20250201+cu124,,789ms,None,0.9999994533658028,716ms,amg_ao_ppb_1024_basic,2050ms,109,,,,628ms,1.5792527251617097img/s,0.0,,633.2108750343323s,,1125ms,710ms,563ms,amg,613ms,1126ms,,,34,1,706ms,1024,633.2108750343323ms,737ms,637.3582756519318,2.7.0.dev20250201+cu124,581ms,,33786,1,35427762688 +0.22.0.dev20250201+cu124,,601ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_inductor_cache_dir'},,3132ms,amg_ao_ppb_1024_fast_cold,404866ms,0,None,,,845ms,1.1731452112213954img/s,,,852.4093952178955s,,404866ms,511ms,395ms,amg,423ms,3534ms,,,30,1,513ms,1024,852.4093952178955ms,545ms,862.1773693561554,2.7.0.dev20250201+cu124,411ms,,29349,1,30775568896 +0.22.0.dev20250201+cu124,,631ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_inductor_cache_dir'},0.9937397953251312,969ms,amg_ao_ppb_1024_fast,8544ms,0,None,,,460ms,2.1431006621923343img/s,188.0,,466.6136395931244s,,8544ms,466ms,384ms,amg,439ms,1389ms,,,30,1,530ms,1024,466.6136395931244ms,562ms,471.85974502563477,2.7.0.dev20250201+cu124,385ms,,29349,1,30775568896 +0.22.0.dev20250201+cu124,,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_inductor_cache_dir'},,,amg_ao_ppb_1024_save_export,,,,,,,0.0img/s,,,336.7823131084442s,,,,,amg,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast,17,1,,1024,,,346.5574824810028,2.7.0.dev20250201+cu124,,0,17377,1,18221665280 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast,651ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_load_export_inductor_cache_dir'},0.9937755975720319,431ms,amg_ao_ppb_1024_load_export_cold,1609ms,10,,,,464ms,2.124166095058908img/s,191.0,,470.7729787826538s,,774ms,428ms,509ms,amg,445ms,1593ms,,,48,1,542ms,1024,470.7729787826538ms,573ms,475.0158474445343,2.7.0.dev20250201+cu124,400ms,,47527,1,49836364288 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast,614ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_load_export_inductor_cache_dir'},0.9937755975720319,547ms,amg_ao_ppb_1024_load_export,2007ms,468,,,,449ms,2.1836456461214064img/s,191.0,,457.94976019859314s,,914ms,544ms,505ms,amg,431ms,1251ms,,,48,1,521ms,1024,457.94976019859314ms,552ms,462.4107701778412,2.7.0.dev20250201+cu124,506ms,,47527,1,49836364288 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast,605ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_load_export_inductor_cache_dir'},0.993802113145407,432ms,amg_ao_ppb_1024_load_export_gpu_preproc,1660ms,468,,None,,458ms,2.1564448274199335img/s,185.0,,463.72621607780457s,,784ms,428ms,468ms,amg,450ms,1598ms,,,48,1,532ms,1024,463.72621607780457ms,559ms,467.6617069244385,2.7.0.dev20250201+cu124,443ms,,47551,1,49861530112 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast,614ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_fast_export_inductor_cache_dir'},0.9937755975720319,436ms,amg_ao_ppb_1024_fast_export_cold,1701ms,468,None,,,449ms,2.1939577313018166img/s,191.0,,455.79729533195496s,,906ms,431ms,397ms,amg,431ms,1598ms,,,48,1,517ms,1024,455.79729533195496ms,556ms,460.1355531215668,2.7.0.dev20250201+cu124,512ms,,47527,1,49836364288 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast,643ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_fast_export_inductor_cache_dir'},0.9937755975720319,432ms,amg_ao_ppb_1024_fast_export,1610ms,10,None,,,468ms,2.107951108360395img/s,191.0,,474.3943045139313s,,777ms,429ms,513ms,amg,453ms,1599ms,,,48,1,552ms,1024,474.3943045139313ms,582ms,478.476078748703,2.7.0.dev20250201+cu124,440ms,,47527,1,49836364288 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast,621ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_fast_export_inductor_cache_dir'},0.993802113145407,433ms,amg_ao_ppb_1024_fast_export_gpu_preproc,1596ms,468,None,None,,452ms,2.1814478117441096img/s,185.0,,458.4111499786377s,,779ms,430ms,426ms,amg,439ms,1595ms,,,48,1,529ms,1024,458.4111499786377ms,550ms,462.8308107852936,2.7.0.dev20250201+cu124,454ms,,47551,1,49861530112 +0.22.0.dev20250201+cu124,,322ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_furious_inductor_cache_dir'},0.977748688557306,2058ms,amg_ao_ppb_1024_fast_furious_cold,780191ms,0,None,,,965ms,1.028371982635375img/s,306.0,None,972.4107782840729s,,780191ms,188ms,142ms,amg,172ms,2836ms,,,29,1,247ms,1024,972.4107782840729ms,277ms,981.1423377990723,2.7.0.dev20250201+cu124,171ms,,28335,1,29712147456 +0.22.0.dev20250201+cu124,,326ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_furious_inductor_cache_dir'},0.977748688557306,1087ms,amg_ao_ppb_1024_fast_furious,10341ms,0,None,,,187ms,5.089925733900441img/s,306.0,None,196.4665207862854s,,10341ms,164ms,133ms,amg,165ms,1096ms,,,29,1,240ms,1024,196.4665207862854ms,264ms,202.26249361038208,2.7.0.dev20250201+cu124,133ms,,28335,1,29712147456 +0.22.0.dev20250201+cu124,,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_furious_inductor_cache_dir'},,,amg_ao_ppb_1024_save_export_furious,,,,,,,0.0img/s,,None,498.73366498947144s,,,,,amg,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,8,1,,1024,,,512.0970723628998,2.7.0.dev20250201+cu124,,0,8754,1,9179703808 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,306ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_load_export_furious_inductor_cache_dir'},0.9737371438629709,201ms,amg_ao_ppb_1024_load_export_furious_cold,1505ms,468,,,,173ms,5.561937676510263img/s,308.0,None,179.79345655441284s,,906ms,167ms,144ms,amg,162ms,906ms,,,28,1,233ms,1024,179.79345655441284ms,264ms,184.16123342514038,2.7.0.dev20250201+cu124,166ms,,27927,1,29284452864 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,305ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_load_export_furious_inductor_cache_dir'},0.9737371438629709,163ms,amg_ao_ppb_1024_load_export_furious,1499ms,468,,,,168ms,5.761707911735736img/s,308.0,None,173.55964851379395s,,799ms,158ms,128ms,amg,157ms,799ms,,,28,1,230ms,1024,173.55964851379395ms,255ms,177.849613904953,2.7.0.dev20250201+cu124,129ms,,27927,1,29284452864 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,283ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_load_export_furious_inductor_cache_dir'},0.9773021879969557,162ms,amg_ao_ppb_1024_load_export_furious_gpu_preproc,1465ms,468,,None,,152ms,6.353136573161692img/s,311.0,None,157.4025661945343s,,908ms,161ms,131ms,amg,136ms,908ms,,,28,1,208ms,1024,157.4025661945343ms,232ms,161.51681876182556,2.7.0.dev20250201+cu124,128ms,,27950,1,29308637696 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,322ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_fast_export_furious_inductor_cache_dir'},0.9737371438629709,197ms,amg_ao_ppb_1024_fast_export_furious_cold,1468ms,468,None,,,177ms,5.4535923991866575img/s,308.0,None,183.36537218093872s,,847ms,178ms,149ms,amg,166ms,848ms,,,28,1,239ms,1024,183.36537218093872ms,265ms,187.86286783218384,2.7.0.dev20250201+cu124,146ms,,27927,1,29284452864 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,314ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_fast_export_furious_inductor_cache_dir'},0.9737371438629709,171ms,amg_ao_ppb_1024_fast_export_furious,1507ms,468,None,,,175ms,5.529882243245596img/s,308.0,None,180.83567714691162s,,837ms,203ms,169ms,amg,165ms,838ms,,,28,1,235ms,1024,180.83567714691162ms,262ms,185.2059760093689,2.7.0.dev20250201+cu124,168ms,,27927,1,29284452864 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,233ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_fast_export_furious_inductor_cache_dir'},0.9738506008329433,137ms,amg_ao_ppb_1024_fast_export_furious_recompiles,50620ms,0,None,,None,244ms,4.001638144664907img/s,312.0,None,249.89765787124634s,,50620ms,131ms,136ms,amg,141ms,37015ms,,,16,1,184ms,1024,249.89765787124634ms,201ms,256.0627791881561,2.7.0.dev20250201+cu124,122ms,,15760,1,16525926912 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,287ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_fast_export_furious_inductor_cache_dir'},0.9773021879969557,161ms,amg_ao_ppb_1024_fast_export_furious_gpu_preproc,1464ms,468,None,None,,152ms,6.317038724475264img/s,311.0,None,158.3020215034485s,,789ms,158ms,138ms,amg,137ms,789ms,,,28,1,209ms,1024,158.3020215034485ms,233ms,163.1717290878296,2.7.0.dev20250201+cu124,128ms,,27950,1,29308637696 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/amg_ao_fast_furious,203ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_fast_export_furious_inductor_cache_dir'},0.9772370635291895,137ms,amg_ao_ppb_1024_fast_export_furious_gpu_preproc_recompiles,2598ms,0,None,None,None,133ms,7.189177970329722img/s,313.0,None,139.09796142578125s,,2598ms,134ms,120ms,amg,123ms,408ms,,,16,1,161ms,1024,139.09796142578125ms,175ms,144.49736833572388,2.7.0.dev20250201+cu124,116ms,,15784,1,16551617024 +0.22.0.dev20250201+cu124,,282ms,None,,130ms,baseline_sps,593ms,0,,,,132ms,7.2705623547104485img/s,,,137.5409426689148s,None,593ms,106ms,139ms,sps,116ms,314ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,181ms,1,137.5409426689148ms,226ms,141.5770561695099,2.7.0.dev20250201+cu124,102ms,,1337,,1402492416 +0.22.0.dev20250201+cu124,,220ms,None,1.0,112ms,sps_ao,647ms,0,,,,121ms,7.953360671304035img/s,0.0,,125.73301291465759s,,647ms,104ms,206ms,sps,110ms,228ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,158ms,1,125.7330129146576ms,209ms,129.6152720451355,2.7.0.dev20250201+cu124,201ms,,1339,,1404989952 +0.22.0.dev20250201+cu124,,222ms,None,1.0,106ms,sps_ao_ppb_1_basic,562ms,0,,,,123ms,7.782472062347266img/s,0.0,,128.4938759803772s,,562ms,102ms,118ms,sps,110ms,235ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,170ms,1,128.4938759803772ms,211ms,132.38522243499756,2.7.0.dev20250201+cu124,124ms,,1339,1,1404989952 +0.22.0.dev20250201+cu124,,215ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_inductor_cache_dir'},,1752ms,sps_ao_ppb_1_fast_cold,319954ms,0,None,,,436ms,2.25470066554133img/s,,,443.51785373687744s,,319954ms,128ms,93ms,sps,102ms,2070ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,158ms,1,443.51785373687744ms,204ms,454.28877544403076,2.7.0.dev20250201+cu124,91ms,,1343,1,1408784896 +0.22.0.dev20250201+cu124,,215ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_inductor_cache_dir'},0.9998689295053482,1006ms,sps_ao_ppb_1_fast,8947ms,0,None,,,124ms,7.688401953604155img/s,0.0,,130.06604051589966s,,8947ms,97ms,96ms,sps,100ms,1014ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,160ms,1,130.06604051589966ms,204ms,136.32297778129578,2.7.0.dev20250201+cu124,93ms,,1302,1,1366200320 +0.22.0.dev20250201+cu124,,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_inductor_cache_dir'},,,sps_ao_ppb_1_save_export,,,,,,,0.0img/s,,,285.2198317050934s,,,,,sps,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast,1,1,,1,,,296.2626416683197,2.7.0.dev20250201+cu124,,0,1326,1,1390578176 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast,218ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_load_export_inductor_cache_dir'},0.9998687945604324,99ms,sps_ao_ppb_1_load_export_cold,433ms,0,,,,154ms,6.231343583764652img/s,0.0,,160.47903418540955s,,433ms,96ms,95ms,sps,140ms,232ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,6,1,212ms,1,160.47903418540955ms,215ms,164.9347858428955,2.7.0.dev20250201+cu124,92ms,,5949,1,6238665728 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast,222ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_load_export_inductor_cache_dir'},0.9998687945604324,102ms,sps_ao_ppb_1_load_export,571ms,0,,,,134ms,7.13271552723891img/s,0.0,,140.19905829429626s,,571ms,96ms,103ms,sps,109ms,276ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,6,1,210ms,1,140.19905829429626ms,215ms,144.91857886314392,2.7.0.dev20250201+cu124,97ms,,5949,1,6238665728 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast,178ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_load_export_inductor_cache_dir'},0.9998684992790222,103ms,sps_ao_ppb_1_load_export_gpu_preproc,546ms,0,,None,,114ms,8.309061811617058img/s,0.0,,120.35053086280823s,,546ms,108ms,98ms,sps,104ms,198ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,6,1,160ms,1,120.35053086280823ms,168ms,125.0929605960846,2.7.0.dev20250201+cu124,95ms,,5971,1,6261886976 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast,218ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_fast_export_inductor_cache_dir'},0.9998687945604324,115ms,sps_ao_ppb_1_fast_export_cold,469ms,0,None,,,117ms,8.127192953841458img/s,0.0,,123.04371333122253s,,469ms,96ms,96ms,sps,102ms,231ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,6,1,170ms,1,123.04371333122253ms,207ms,127.78972721099854,2.7.0.dev20250201+cu124,93ms,,5949,1,6238665728 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast,214ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_fast_export_inductor_cache_dir'},0.9998687945604324,99ms,sps_ao_ppb_1_fast_export,457ms,0,None,,,113ms,8.353357150253501img/s,0.0,,119.71234822273254s,,457ms,98ms,127ms,sps,102ms,226ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,6,1,150ms,1,119.71234822273254ms,194ms,124.14609551429749,2.7.0.dev20250201+cu124,101ms,,5949,1,6238665728 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast,174ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_fast_export_inductor_cache_dir'},0.9998684992790222,158ms,sps_ao_ppb_1_fast_export_gpu_preproc,494ms,0,None,None,,111ms,8.544958106296153img/s,0.0,,117.02807521820068s,,494ms,161ms,161ms,sps,102ms,188ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,6,1,155ms,1,117.02807521820068ms,165ms,121.18485426902771,2.7.0.dev20250201+cu124,155ms,,5971,1,6261886976 +0.22.0.dev20250201+cu124,,72ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_furious_inductor_cache_dir'},0.9996836956143379,2866ms,sps_ao_ppb_1_fast_furious_cold,565385ms,0,None,,,602ms,1.6434863733339082img/s,0.0,None,608.4626049995422s,,565385ms,28ms,25ms,sps,30ms,3429ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,0,1,59ms,1,608.4626049995422ms,64ms,619.5543768405914,2.7.0.dev20250201+cu124,20ms,,861,1,903450624 +0.22.0.dev20250201+cu124,,72ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_furious_inductor_cache_dir'},0.9996836956143379,617ms,sps_ao_ppb_1_fast_furious,7989ms,0,None,,,45ms,19.35863964467831img/s,0.0,None,51.656522274017334s,,7989ms,23ms,22ms,sps,32ms,625ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,0,1,63ms,1,51.656522274017334ms,68ms,58.16215395927429,2.7.0.dev20250201+cu124,18ms,,861,1,903450624 +0.22.0.dev20250201+cu124,,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_furious_inductor_cache_dir'},,,sps_ao_ppb_1_save_export_furious,,,,,,,0.0img/s,,None,367.0964250564575s,,,,,sps,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,0,1,,1,,,379.5168604850769,2.7.0.dev20250201+cu124,,0,861,1,903450624 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,72ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_load_export_furious_inductor_cache_dir'},0.999670289516449,50ms,sps_ao_ppb_1_load_export_furious_cold,763ms,0,,,,42ms,20.10115843340511img/s,0.0,None,49.7483766078949s,,763ms,24ms,45ms,sps,35ms,78ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,66ms,1,49.7483766078949ms,68ms,54.233083724975586,2.7.0.dev20250201+cu124,57ms,,1686,1,1768025088 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,69ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_load_export_furious_inductor_cache_dir'},0.999670289516449,61ms,sps_ao_ppb_1_load_export_furious,683ms,0,,,,38ms,22.96430070006913img/s,0.0,None,43.54585027694702s,,683ms,51ms,54ms,sps,31ms,80ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,63ms,1,43.54585027694702ms,66ms,47.646597385406494,2.7.0.dev20250201+cu124,57ms,,1686,1,1768025088 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,28ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_load_export_furious_inductor_cache_dir'},0.9996740021705628,20ms,sps_ao_ppb_1_load_export_furious_gpu_preproc,658ms,0,,None,,21ms,39.00878151072386img/s,0.0,None,25.635253429412842s,,658ms,19ms,21ms,sps,19ms,31ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,25ms,1,25.635253429412842ms,26ms,30.03287935256958,2.7.0.dev20250201+cu124,17ms,,1711,1,1794153472 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,77ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_fast_export_furious_inductor_cache_dir'},0.999670289516449,23ms,sps_ao_ppb_1_fast_export_furious_cold,667ms,0,None,,,40ms,21.668760920646257img/s,0.0,None,46.14938545227051s,,667ms,21ms,23ms,sps,33ms,121ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,67ms,1,46.14938545227051ms,69ms,51.00432109832764,2.7.0.dev20250201+cu124,18ms,,1686,1,1768025088 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,71ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_fast_export_furious_inductor_cache_dir'},0.999670289516449,24ms,sps_ao_ppb_1_fast_export_furious,770ms,0,None,,,35ms,24.842548071007272img/s,0.0,None,40.253519773483276s,,770ms,23ms,23ms,sps,30ms,87ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,60ms,1,40.253519773483276ms,66ms,45.05125379562378,2.7.0.dev20250201+cu124,19ms,,1686,1,1768025088 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,72ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_fast_export_furious_inductor_cache_dir'},0.22583148045595317,25ms,sps_ao_ppb_1_fast_export_furious_recompiles,8888ms,0,None,,None,45ms,19.64123137979746img/s,0.0,None,50.913304805755615s,,8888ms,24ms,23ms,sps,31ms,107ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,62ms,1,50.913304805755615ms,67ms,57.28812289237976,2.7.0.dev20250201+cu124,19ms,,1686,1,1768025088 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,30ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_fast_export_furious_inductor_cache_dir'},0.9996740021705628,21ms,sps_ao_ppb_1_fast_export_furious_gpu_preproc,764ms,0,None,None,,22ms,36.6053844956782img/s,0.0,None,27.318385362625122s,,764ms,19ms,20ms,sps,20ms,32ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,27ms,1,27.318385362625122ms,28ms,32.168028831481934,2.7.0.dev20250201+cu124,17ms,,1711,1,1794153472 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/sps_ao_fast_furious,28ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/sps_fast_export_furious_inductor_cache_dir'},0.2360341116612085,21ms,sps_ao_ppb_1_fast_export_furious_gpu_preproc_recompiles,2423ms,0,None,None,None,22ms,36.49431885806781img/s,0.0,None,27.401525259017944s,,2423ms,19ms,21ms,sps,19ms,32ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,25ms,1,27.401525259017944ms,26ms,32.687700271606445,2.7.0.dev20250201+cu124,17ms,,1711,1,1794153472 +0.22.0.dev20250201+cu124,,1271ms,None,,883ms,baseline_mps,2023ms,525,,,,363ms,2.673329025663599img/s,,,374.06544065475464s,None,783ms,1250ms,552ms,mps,276ms,1639ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,1,1,681ms,,374.06544065475464ms,914ms,378.37125968933105,2.7.0.dev20250201+cu124,264ms,,1337,,1402492416 +0.22.0.dev20250201+cu124,,236ms,None,0.999999164044857,122ms,mps_ao,577ms,0,,,,135ms,7.037101001300518img/s,0.0,,142.10397148132324s,,577ms,118ms,139ms,mps,121ms,343ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,8,1,189ms,,142.10397148132324ms,222ms,146.95012307167053,2.7.0.dev20250201+cu124,150ms,,8021,,8411175424 +0.22.0.dev20250201+cu124,,247ms,None,0.999999164044857,119ms,mps_ao_ppb_None_basic,504ms,0,,,,148ms,6.436594044650894img/s,0.0,,155.36167001724243s,,504ms,116ms,238ms,mps,126ms,435ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,8,1,225ms,,155.36167001724243ms,233ms,159.40889310836792,2.7.0.dev20250201+cu124,103ms,,8021,1,8411175424 +0.22.0.dev20250201+cu124,,235ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_inductor_cache_dir'},,1555ms,mps_ao_ppb_None_fast_cold,333308ms,0,None,,,591ms,1.6704230439798613img/s,,,598.6507451534271s,,333308ms,126ms,116ms,mps,129ms,62595ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,8,1,215ms,,598.6507451534271ms,221ms,608.3473885059357,2.7.0.dev20250201+cu124,97ms,,8021,1,8411176448 +0.22.0.dev20250201+cu124,,239ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_inductor_cache_dir'},0.9983837329149247,427ms,mps_ao_ppb_None_fast,8617ms,0,None,,,144ms,6.6146677704234085img/s,0.0,,151.17917251586914s,,8617ms,107ms,230ms,mps,113ms,1446ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,8,1,213ms,,151.17917251586914ms,218ms,156.3648726940155,2.7.0.dev20250201+cu124,94ms,,8021,1,8411176448 +0.22.0.dev20250201+cu124,,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_inductor_cache_dir'},,,mps_ao_ppb_None_save_export,,,,,,,0.0img/s,,,206.32550930976868s,,,,,mps,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast,1,1,,,,,214.1670503616333,2.7.0.dev20250201+cu124,,0,1326,1,1390578176 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast,229ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_load_export_inductor_cache_dir'},0.9983834671974182,219ms,mps_ao_ppb_None_load_export_cold,481ms,0,,,,126ms,7.508148238612264img/s,0.0,,133.1886329650879s,,481ms,133ms,138ms,mps,112ms,267ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,7,1,175ms,,133.1886329650879ms,214ms,137.3904402256012,2.7.0.dev20250201+cu124,102ms,,7206,1,7556661248 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast,239ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_load_export_inductor_cache_dir'},0.9983834671974182,113ms,mps_ao_ppb_None_load_export,467ms,0,,,,123ms,7.699697903486223img/s,0.0,,129.87522530555725s,,467ms,109ms,159ms,mps,110ms,281ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,7,1,169ms,,129.87522530555725ms,210ms,133.87165689468384,2.7.0.dev20250201+cu124,103ms,,7206,1,7556661248 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast,217ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_load_export_inductor_cache_dir'},0.9984678574204445,150ms,mps_ao_ppb_None_load_export_gpu_preproc,596ms,0,,None,,138ms,6.876547118132811img/s,0.0,,145.42182040214539s,,596ms,174ms,146ms,mps,130ms,247ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,7,1,182ms,,145.42182040214539ms,194ms,149.7709481716156,2.7.0.dev20250201+cu124,96ms,,7230,1,7581827072 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast,223ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_fast_export_inductor_cache_dir'},0.998383906185627,109ms,mps_ao_ppb_None_fast_export_cold,63209ms,0,None,,,279ms,3.46900314658375img/s,0.0,,288.2672507762909s,,63209ms,108ms,139ms,mps,108ms,55253ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,7,1,149ms,,288.2672507762909ms,192ms,295.21097111701965,2.7.0.dev20250201+cu124,210ms,,7206,1,7556661248 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast,225ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_fast_export_inductor_cache_dir'},0.998383847117424,219ms,mps_ao_ppb_None_fast_export,3408ms,0,None,,,127ms,7.378673337507828img/s,0.0,,135.52571773529053s,,3408ms,131ms,133ms,mps,110ms,1527ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,7,1,162ms,,135.52571773529053ms,210ms,140.8395688533783,2.7.0.dev20250201+cu124,211ms,,7206,1,7556661248 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast,197ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_fast_export_inductor_cache_dir'},0.998423279285431,176ms,mps_ao_ppb_None_fast_export_gpu_preproc,4037ms,0,None,None,,139ms,6.776701628632778img/s,0.0,,147.5644133090973s,,4037ms,111ms,142ms,mps,125ms,1977ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,7,1,177ms,,147.5644133090973ms,182ms,154.06113982200623,2.7.0.dev20250201+cu124,108ms,,7230,1,7581827072 +0.22.0.dev20250201+cu124,,90ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_furious_inductor_cache_dir'},0.9973100498914719,1126ms,mps_ao_ppb_None_fast_furious_cold,593416ms,0,None,,,732ms,1.3513049945474962img/s,0.0,None,740.0253858566284s,,593416ms,69ms,62ms,mps,45ms,58562ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,4,1,71ms,,740.0253858566284ms,75ms,752.7034668922424,2.7.0.dev20250201+cu124,55ms,,4222,1,4427842560 +0.22.0.dev20250201+cu124,,80ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_furious_inductor_cache_dir'},0.9973101171851159,563ms,mps_ao_ppb_None_fast_furious,9626ms,0,None,,,51ms,15.845165465673302img/s,0.0,None,63.110732555389404s,,9626ms,70ms,68ms,mps,36ms,1443ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,4,1,64ms,,63.110732555389404ms,70ms,68.8342227935791,2.7.0.dev20250201+cu124,60ms,,4222,1,4427842560 +0.22.0.dev20250201+cu124,,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_furious_inductor_cache_dir'},,,mps_ao_ppb_None_save_export_furious,,,,,,,0.0img/s,,None,310.3892893791199s,,,,,mps,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,0,1,,,,,319.1325914859772,2.7.0.dev20250201+cu124,,0,861,1,903450624 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,88ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_load_export_furious_inductor_cache_dir'},0.9971953355669976,59ms,mps_ao_ppb_None_load_export_furious_cold,747ms,0,,,,48ms,18.330754801750256img/s,0.0,None,54.55312728881836s,,747ms,39ms,70ms,mps,43ms,211ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,3,1,72ms,,54.55312728881836ms,80ms,58.5643265247345,2.7.0.dev20250201+cu124,68ms,,3813,1,3998387200 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,94ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_load_export_furious_inductor_cache_dir'},0.9971953355669976,64ms,mps_ao_ppb_None_load_export_furious,807ms,0,,,,57ms,15.401551852759791img/s,0.0,None,64.92852210998535s,,807ms,54ms,44ms,mps,53ms,310ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,3,1,81ms,,64.92852210998535ms,85ms,69.41558504104614,2.7.0.dev20250201+cu124,70ms,,3813,1,3998387200 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,53ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_load_export_furious_inductor_cache_dir'},0.9970021231770515,38ms,mps_ao_ppb_None_load_export_furious_gpu_preproc,671ms,0,,None,,34ms,24.16679656674749img/s,0.0,None,41.379087924957275s,,671ms,37ms,37ms,mps,31ms,185ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,3,1,41ms,,41.379087924957275ms,43ms,45.440187215805054,2.7.0.dev20250201+cu124,31ms,,3837,1,4023553024 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,78ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_fast_export_furious_inductor_cache_dir'},0.9965503122806549,66ms,mps_ao_ppb_None_fast_export_furious_cold,82119ms,0,None,,,201ms,4.7872645002077805img/s,0.0,None,208.88755989074707s,,82119ms,64ms,68ms,mps,34ms,58377ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,3,1,65ms,,208.88755989074707ms,70ms,217.18880224227905,2.7.0.dev20250201+cu124,58ms,,3813,1,3998387200 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,77ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_fast_export_furious_inductor_cache_dir'},0.9965502863526344,58ms,mps_ao_ppb_None_fast_export_furious,3781ms,0,None,,,41ms,20.721952369178233img/s,0.0,None,48.25800108909607s,,3781ms,36ms,70ms,mps,31ms,1725ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,3,1,60ms,,48.25800108909607ms,67ms,54.08296799659729,2.7.0.dev20250201+cu124,27ms,,3813,1,3998387200 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,79ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_fast_export_furious_inductor_cache_dir'},0.9955503289103508,58ms,mps_ao_ppb_None_fast_export_furious_recompiles,14159ms,0,None,,None,60ms,14.717359525686513img/s,0.0,None,67.94697093963623s,,14159ms,30ms,33ms,mps,34ms,7675ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,3,1,63ms,,67.94697093963623ms,69ms,74.24112939834595,2.7.0.dev20250201+cu124,24ms,,3813,1,3998387200 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,45ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_fast_export_furious_inductor_cache_dir'},0.9961833162903786,26ms,mps_ao_ppb_None_fast_export_furious_gpu_preproc,4055ms,0,None,None,,30ms,27.10634670561908img/s,0.0,None,36.89172911643982s,,4055ms,26ms,29ms,mps,22ms,1531ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,3,1,32ms,,36.89172911643982ms,36ms,41.898804664611816,2.7.0.dev20250201+cu124,19ms,,3837,1,4023553024 +0.22.0.dev20250201+cu124,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/exported_models/mps_ao_fast_furious,43ms,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/mps_fast_export_furious_inductor_cache_dir'},0.9956747673153877,25ms,mps_ao_ppb_None_fast_export_furious_gpu_preproc_recompiles,5487ms,0,None,None,None,32ms,25.886837983308926img/s,0.0,None,38.62966966629028s,,5487ms,26ms,29ms,mps,23ms,1561ms,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_21/amg_baseline_annotations,,3,1,32ms,,38.62966966629028ms,36ms,44.1503472328186,2.7.0.dev20250201+cu124,19ms,,3837,1,4023553024 From d3306b22b0e9cba09762c335757c1dcfbd96f170 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:05:48 -0800 Subject: [PATCH 072/115] Add mx_fp8_bf16 kernel (#1637) * Add mx_fp8_bf16 kernel stack-info: PR: https://github.com/pytorch/ao/pull/1637, branch: drisspg/stack/31 * Add mx_fp4_kernel (#1661) stack-info: PR: https://github.com/pytorch/ao/pull/1661 --- setup.py | 9 +- test/prototype/mx_formats/test_mx_mm.py | 74 +++++ .../cuda/mx_kernels/mx_fp_cutlass_kernels.cu | 285 ++++++++++++++++++ torchao/ops.py | 80 +++++ torchao/prototype/mx_formats/utils.py | 53 ++++ 5 files changed, 497 insertions(+), 4 deletions(-) create mode 100644 test/prototype/mx_formats/test_mx_mm.py create mode 100644 torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu create mode 100644 torchao/prototype/mx_formats/utils.py diff --git a/setup.py b/setup.py index 67a8d2e576..6ee93bc9ab 100644 --- a/setup.py +++ b/setup.py @@ -215,10 +215,7 @@ def get_extensions(): extra_link_args = [] extra_compile_args = { "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"], - "nvcc": [ - "-O3" if not debug_mode else "-O0", - "-t=0", - ], + "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"], } if not IS_WINDOWS: @@ -257,12 +254,16 @@ def get_extensions(): use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") cutlass_include_dir = os.path.join(cutlass_dir, "include") + cutlass_tools_include_dir = os.path.join( + cutlass_dir, "tools", "util", "include" + ) cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir) if use_cutlass: extra_compile_args["nvcc"].extend( [ "-DTORCHAO_USE_CUTLASS", "-I" + cutlass_include_dir, + "-I" + cutlass_tools_include_dir, "-I" + cutlass_extensions_include_dir, ] ) diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py new file mode 100644 index 0000000000..7c66c5d053 --- /dev/null +++ b/test/prototype/mx_formats/test_mx_mm.py @@ -0,0 +1,74 @@ +import pytest +import torch + +from torchao.float8.float8_utils import compute_error +from torchao.ops import mx_fp4_bf16, mx_fp8_bf16 +from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor +from torchao.prototype.mx_formats.utils import to_blocked +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100 + +if not TORCH_VERSION_AT_LEAST_2_4: + pytest.skip("Unsupported PyTorch version", allow_module_level=True) + + +def run_matrix_test(M: int, K: int, N: int, format) -> float: + dtype = torch.bfloat16 + device = torch.device("cuda") + + a = torch.rand((M, K), dtype=dtype, device=device) + b = torch.rand((N, K), dtype=dtype, device=device) + + fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4 + mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16 + + a_mx = MXTensor.to_mx(a, fmt, 32) + b_mx = MXTensor.to_mx(b, fmt, 32) + + a_data = a_mx._data + b_data = b_mx._data + assert b_data.is_contiguous() + b_data = b_data.transpose(-1, -2) + + a_scale = a_mx._scale_e8m0.view(M, K // 32) + b_scale = b_mx._scale_e8m0.view(N, K // 32) + + a_scale_block = to_blocked(a_scale) + b_scale_block = to_blocked(b_scale) + + out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose( + -1, -2 + ) + out = mx_func(a_data, b_data, a_scale_block, b_scale_block) + + return compute_error(out_hp, out).item() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8" +) +@pytest.mark.parametrize( + "size", + [ + (128, 128, 128), + (256, 256, 256), + (384, 384, 384), # Small + (512, 512, 512), + (768, 768, 768), # Medium + (1024, 1024, 1024), + (8192, 8192, 8192), # Large + (128, 256, 384), + (256, 384, 512), # Non-square + (129, 256, 384), + (133, 512, 528), # Non-aligned + ], + ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}", +) +@pytest.mark.parametrize("format", ["fp8", "fp4"]) +def test_matrix_multiplication(size, format): + M, K, N = size + sqnr = run_matrix_test(M, K, N, format) + threshold = 80.0 + assert ( + sqnr >= threshold + ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}" diff --git a/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu b/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu new file mode 100644 index 0000000000..e01d363ec3 --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu @@ -0,0 +1,285 @@ +#include + +#include +#include +#include +#include +#include +#include + +#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \ + defined(CUDA_VERSION) && (CUDA_VERSION >= 12080) +#define BUILD_MX_KERNELS_CUTLASS +#endif + +#if defined(BUILD_MX_KERNELS_CUTLASS) + +#include "cute/tensor.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/util/packed_stride.hpp" + + +#endif + +namespace torchao { + +#if defined(BUILD_MX_KERNELS_CUTLASS) +namespace { + +using namespace cute; + +template +constexpr int GetAlignment() { + if constexpr (std::is_same_v>) + return 32; + return 16; +} + +template +void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale, + at::Tensor& b_scale, at::Tensor& out, int M, int K, int N) { + // A matrix configuration + using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand + constexpr int AlignmentA = GetAlignment(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + constexpr int AlignmentB = GetAlignment(); // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand + using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand + using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand + constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + // Kernel functional config + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag + + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + PerSmTileShape_MNK, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutCTag, AlignmentC, + ElementD, LayoutDTag, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutATag, AlignmentA, + ElementB, LayoutBTag, AlignmentB, + ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + void>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Reference device GEMM implementation type + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig; + + // Initialize strides using packed stride configuration + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1)); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1)); + auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1)); + + // Initialize scale factor layouts using block scaled configuration + auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1)); + auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1)); + + using DtypeA = typename ElementA::DataType; + using DtypeB = typename ElementB::DataType; + using DtypeScaleA = typename ElementA::ScaleFactorType; + using DtypeScaleB = typename ElementB::ScaleFactorType; + using DtypeOut = ElementD; + + Gemm gemm; + + auto A_ptr = reinterpret_cast(a.data_ptr()); + auto B_ptr = reinterpret_cast(b.data_ptr()); + auto SFA_ptr = reinterpret_cast(a_scale.data_ptr()); + auto SFB_ptr = reinterpret_cast(b_scale.data_ptr()); + auto out_ptr = reinterpret_cast(out.data_ptr()); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, 1}, + { // Mainloop arguments + A_ptr, stride_A, + B_ptr, stride_B, + SFA_ptr, layout_SFA, + SFB_ptr, layout_SFB + }, + { // Epilogue arguments + {1.0, 0.0}, + nullptr, StrideC{}, // No bias for now + out_ptr, stride_D + } + }; + + // arguments.scheduler.max_swizzle_size = 8; + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement"); + // Allocate workspace memory + size_t workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = a.new_empty( + {static_cast(workspace_size)}, + at::TensorOptions().dtype(at::kByte)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize"); + + status = gemm.run(at::cuda::getCurrentCUDAStream()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run", cutlass::cutlassGetStatusString(status)); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + +} +} +#endif + +void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale){ + TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor"); + TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor"); + TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor"); + TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor"); + + // Check matrix dimensions + TORCH_CHECK(a.dim() == 2, "a must be a matrix"); + TORCH_CHECK(b.dim() == 2, "b must be a matrix"); + + // Get dimensions + auto M = a.size(0); + auto K = a.size(1); + auto N = b.size(1); + + TORCH_CHECK(b.size(0) == K, + "Incompatible matrix dimensions: a is ", M, "x", K, " but b is ", b.size(0), "x", N); + + // Needed for TMA store + TORCH_CHECK(N % 8 == 0, "N must be a multiple of 16 but got, ", N); + + // Check 16-byte alignment for input tensors + TORCH_CHECK( + reinterpret_cast(a.data_ptr()) % 16 == 0, + "Input tensor 'a' must be 16-byte aligned"); + TORCH_CHECK( + reinterpret_cast(b.data_ptr()) % 16 == 0, + "Input tensor 'b' must be 16-byte aligned"); + + auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; }; + auto num_k_blocks = ceil_div(K, 32); + // For a_scale, we expect elements or M* ceil(K/32) elements + auto expected_a_scale_size = 128 * ceil_div(M, 128) * num_k_blocks; + TORCH_CHECK(a_scale.numel() == expected_a_scale_size, "Expected b_scale_size to be ", expected_a_scale_size, " but got ", a_scale.numel()); + + // For b_scale, we expect N * ceil(K/32) elements + auto expected_b_scale_size = 128 * ceil_div(N, 128) * num_k_blocks; + TORCH_CHECK(b_scale.numel() == expected_b_scale_size, "Expected a_scale_size to be ", expected_b_scale_size, " but got ", b_scale.numel()); + + // Check tensor strides for optimal memory layout + TORCH_CHECK( + a.stride(1) == 1, + "Input tensor 'a' must be contiguous in the K dimension (row-major)"); + TORCH_CHECK( + b.stride(0) == 1, + "Input tensor 'b' must be contiguous in the K dimension (column-major)"); +} + + +at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale, + at::Tensor b_scale) { +#if defined(BUILD_MX_KERNELS_CUTLASS) + validate(a, b, a_scale, b_scale); + auto M = a.size(0); + auto K = a.size(1); + auto N = b.size(1); + + auto out = + at::empty({M, N}, a.options().dtype(at::kBFloat16)); + using ElementA = cutlass::mx_float8_t; + using ElementB = cutlass::mx_float8_t; + using ElementD = cutlass::bfloat16_t; + + using MmaTileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_2,_1,_1>; + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + run_gemm(a, b, a_scale, b_scale, out, M, K, N); + return out; + #else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale, + at::Tensor b_scale) { +#if defined(BUILD_MX_KERNELS_CUTLASS) + TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor"); + TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor"); + TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor"); + TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor"); + + auto M = a.size(0); + auto K = a.size(1) * 2; + auto N = b.size(1); + + auto out = + at::empty({M, N}, a.options().dtype(at::kBFloat16)); + using ElementA = cutlass::mx_float4_t; + using ElementB = cutlass::mx_float4_t; + using ElementD = cutlass::bfloat16_t; + + using MmaTileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_2,_1,_1>; + using PerSmTileShape_MNK = Shape<_128,_128,_128>; + + run_gemm(a, b, a_scale, b_scale, out, M, K, N); + return out; +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, __func__); + return at::Tensor{}; +#endif +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16); +} +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16); +} + + + +} // namespace torchao diff --git a/torchao/ops.py b/torchao/ops.py index 8b573876f2..56980b17f1 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -25,6 +25,8 @@ lib.define( "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) +lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor") +lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor") def register_custom_op(name): @@ -592,3 +594,81 @@ def _( bias: Tensor, ) -> Tensor: return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) + + +def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): + """Defines a matmul between two fp8 tensors w/ MX scales in E8MO and returns a bf16 tensor. + + This op is prototype subject to change. + + Note: The mx scales are E8MO tensors store in uint8 tensors (for now). + The layout of the scales is very particular, see: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + A: fp8 tensor w/ dtype = torch.float8_e4m3fn + B: fp8 tensor w/ dtype = torch.float8_e4m3fn + A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout + B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout + + Returns: + MXN bf16 Tensor + + """ + torch._check( + A.dtype == torch.float8_e4m3fn, + lambda: f"Input tensor A must be float8_e4m3fn, got {A.dtype}", + ) + torch._check( + B.dtype == torch.float8_e4m3fn, + lambda: f"Input tensor B must be float8_e4m3fn, got {B.dtype}", + ) + + # TODO - Once e8m0 dtype is added to core udpate + # Check scale tensors are uint8 + torch._check( + A_scale.dtype == torch.uint8, + lambda: f"A_scale tensor must be uint8, got {A_scale.dtype}", + ) + torch._check( + B_scale.dtype == torch.uint8, + lambda: f"B_scale tensor must be uint8, got {B_scale.dtype}", + ) + return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale) + + +@register_custom_op("torchao::mx_fp8_bf16") +def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): + """Meta impl for mx_fp8_bf16""" + return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device) + + +def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): + """Defines a matmul between two fp4 tensors w/ MX scales in E8MO and returns a bf16 tensor. + + The expected format is fp4_e2m1 specified: + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final.pdf (Section 5.3.3) + + Note: The mx scales are E8MO tensors stored in uint8 tensors (for now). + The layout of the scales is very particular, see: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + + Args: + A: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1) + B: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1) + A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout + B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout + + Returns: + MXN bf16 Tensor + + """ + return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale) + + +@register_custom_op("torchao::mx_fp4_bf16") +def meta_mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): + """Meta impl for mx_fp4_bf16""" + # Assume that the contraction happens in the K dim thus M,N are perserved post bit pack + return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device) diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py new file mode 100644 index 0000000000..4cdc26109d --- /dev/null +++ b/torchao/prototype/mx_formats/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F + +Tensor = torch.Tensor + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def to_blocked(input_matrix) -> Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. + + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Pad out and view as tiles of (128, 4) + padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128)) + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + + # rearrange all tiles + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + # Layout rearranged tiles according to second pic + return rearranged.flatten() + + +def _to_blocked_single(scales: Tensor) -> Tensor: + """Assume that we have a 128x4 block of scales in K Major order + + To see more information on the individual tile layout: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + assert scales.shape == (128, 4) + scales_tiled = scales.view(4, 32, 4) # view as 4 - (32, 4) tiles + return scales_tiled.transpose(0, 1).reshape(32, 16) # Interleave tiles From dff29c0c8b6b2b8ff5834743ff8f106cd564c5b3 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 13 Feb 2025 09:47:43 -0800 Subject: [PATCH 073/115] Fix use_hqq for int4_weight_only quantize (#1707) Fix HQQ call for int4_weight_only quantize --- torchao/_models/llama/generate.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index b1d3475601..69b0fb6e99 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -420,10 +420,9 @@ def ffn_or_attn_only(mod, fqn): else: quantize_(model, int8_dynamic_activation_int8_weight()) if "int4wo" in quantization: + use_hqq = False if "hqq" in quantization: use_hqq = True - else: - use_hqq = False group_size = int(quantization.split("-")[1]) assert ( group_size @@ -434,7 +433,7 @@ def ffn_or_attn_only(mod, fqn): 256, ] ), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" - quantize_(model, int4_weight_only(group_size=group_size)) + quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) elif "int8adq-int4w-symm" in quantization: from torchao.dtypes import CutlassInt4PackedLayout From 52f4737f22bd4e650cfb6730a2afda2609c8a314 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 13 Feb 2025 16:23:19 -0800 Subject: [PATCH 074/115] [bc-breaking] enable direct configuration in quantize_ (#1595) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 25 ++- test/hqq/test_hqq_affine.py | 7 +- test/quantization/test_qat.py | 2 +- test/quantization/test_quant_api.py | 25 +++ torchao/core/__init__.py | 0 torchao/core/config.py | 29 +++ torchao/quantization/__init__.py | 6 + torchao/quantization/qat/__init__.py | 4 + torchao/quantization/qat/api.py | 118 +++++++----- torchao/quantization/quant_api.py | 224 +++++++++++++---------- torchao/quantization/transform_module.py | 46 +++++ 11 files changed, 334 insertions(+), 152 deletions(-) create mode 100644 torchao/core/__init__.py create mode 100644 torchao/core/config.py create mode 100644 torchao/quantization/transform_module.py diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 52b25dab82..53ca470b04 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -8,6 +8,7 @@ run_tests, ) +from torchao.core.config import AOBaseConfig from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout from torchao.quantization import ( float8_weight_only, @@ -16,6 +17,7 @@ int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_weight_only, + quantize_, ) from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain from torchao.utils import ( @@ -82,7 +84,8 @@ def test_tensor_core_layout_transpose(self): t = linear.weight shape = t.shape apply_int4_weight_only_quant = int4_weight_only(group_size=32) - ql = apply_int4_weight_only_quant(linear) + quantize_(linear, apply_int4_weight_only_quant) + ql = linear aqt = ql.weight aqt_shape = aqt.shape self.assertEqual(aqt_shape, shape) @@ -102,7 +105,12 @@ def test_tensor_core_layout_transpose(self): ) def test_weights_only(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseConfig): + quantize_(linear, apply_quant) + ql = linear + else: + # TODO(#1690): delete this once config migration is done + ql = apply_quant(linear) with tempfile.NamedTemporaryFile() as f: torch.save(ql.state_dict(), f) f.seek(0) @@ -181,7 +189,12 @@ def apply_uint6_weight_only_quant(linear): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_print_quantized_module(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseConfig): + quantize_(linear, apply_quant) + ql = linear + else: + # TODO(#1690): delete this once config migration is done + ql = apply_quant(linear) assert "AffineQuantizedTensor" in str(ql) @@ -195,7 +208,11 @@ def test_flatten_unflatten(self, device, dtype): apply_quant_list = get_quantization_functions(False, True, device) for apply_quant in apply_quant_list: linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) - ql = apply_quant(linear) + if isinstance(apply_quant, AOBaseConfig): + quantize_(linear, apply_quant) + else: + # TODO(#1690): delete this once config migration is done + ql = apply_quant(linear) lp_tensor = ql.weight tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__() tensor_data_dict = { diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 381886d594..096c9d26ba 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -6,6 +6,7 @@ MappingType, ZeroPointDomain, int4_weight_only, + quantize_, uintx_weight_only, ) from torchao.utils import ( @@ -51,9 +52,9 @@ def _eval_hqq(dtype): ) dummy_linear.weight.data = W if dtype == torch.uint4: - q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)( - dummy_linear - ).weight + config = int4_weight_only(group_size=max(block_size), use_hqq=True) + quantize_(dummy_linear, config) + q_tensor_hqq = dummy_linear.weight else: q_tensor_hqq = uintx_weight_only( dtype, group_size=max(block_size), use_hqq=True diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 8a78b8b387..82324394a8 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self): @unittest.skipIf( not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) - def test_quantize_api(self): + def test_quantize_api_standalone(self): """ Test that the following: diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index caba1cf31f..acd9b50c5a 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -40,6 +40,7 @@ Int4WeightOnlyQuantizedLinearWeight, Int8WeightOnlyQuantizedLinearWeight, ) +from torchao.quantization.utils import compute_error from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -783,6 +784,30 @@ def test_int4wo_cpu(self, dtype, x_dim): assert "_weight_int4pack_mm_for_cpu" in code[0] assert "aten.mm.default" not in code[0] + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_int4_weight_only_numerics(self): + """ + Simple test of e2e int4_weight_only workflow, comparing numerics + to a bfloat16 baseline. + """ + # set up inputs + x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 + # is that expected? + m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() + m_int4_wo = copy.deepcopy(m_ref) + + # quantize + quantize_(m_int4_wo, int4_weight_only()) + + with torch.no_grad(): + y_ref = m_ref(x) + y_int4_wo = m_int4_wo(x) + + sqnr = compute_error(y_ref, y_int4_wo) + assert sqnr >= 20, f"SQNR {sqnr} is too low" + class TestMultiTensorFlow(TestCase): @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") diff --git a/torchao/core/__init__.py b/torchao/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/core/config.py b/torchao/core/config.py new file mode 100644 index 0000000000..14a7b8dc66 --- /dev/null +++ b/torchao/core/config.py @@ -0,0 +1,29 @@ +import abc + + +class AOBaseConfig(abc.ABC): + """ + If a workflow config inherits from this then `quantize_` knows + how to a apply it to a model. For example:: + + # user facing code + class WorkflowFooConfig(AOBaseConfig): ... + # configuration for workflow `Foo` is defined here + bar = 'baz' + + # non user facing code + @register_quantize_module_handler(WorkflowFooConfig) + def _transform( + mod: torch.nn.Module, + config: WorkflowFooConfig, + ) -> torch.nn.Module: + # the transform is implemented here, usually a tensor sublass + # weight swap or a module swap + ... + + # then, the user calls `quantize_` with a config, and `_transform` is called + # under the hood by `quantize_. + + """ + + pass diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index aa4a51d497..71e8de337a 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. + from torchao.kernel import ( int_scaled_matmul, safe_int_mm, @@ -45,6 +46,7 @@ AffineQuantizedObserverBase, ) from .quant_api import ( + Int4WeightOnlyConfig, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, @@ -85,6 +87,7 @@ swap_linear_with_smooth_fq_linear, ) from .subclass import * # noqa: F403 +from .transform_module import register_quantize_module_handler from .unified import Quantizer, TwoStepQuantizer from .utils import ( compute_error, @@ -117,6 +120,7 @@ "fpx_weight_only", "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", + "Int4WeightOnlyConfig", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", @@ -144,6 +148,8 @@ # operators/kernels "safe_int_mm", "int_scaled_matmul", + # registration of module transforms for quantize_ + "register_quantize_module_handler", # dataclasses and types "MappingType", "ZeroPointDomain", diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py index 15008e03ea..5dc3d8e008 100644 --- a/torchao/quantization/qat/__init__.py +++ b/torchao/quantization/qat/__init__.py @@ -1,6 +1,8 @@ from .api import ( ComposableQATQuantizer, FakeQuantizeConfig, + FromIntXQuantizationAwareTrainingConfig, + IntXQuantizationAwareTrainingConfig, from_intx_quantization_aware_training, intx_quantization_aware_training, ) @@ -20,4 +22,6 @@ "Int8DynActInt4WeightQATQuantizer", "intx_quantization_aware_training", "from_intx_quantization_aware_training", + "FromIntXQuantizationAwareTrainingConfig", + "IntXQuantizationAwareTrainingConfig", ] diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py index 925a0eed3c..d7e8f204cc 100644 --- a/torchao/quantization/qat/api.py +++ b/torchao/quantization/qat/api.py @@ -5,10 +5,11 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Union +from typing import Any, List, Optional, Union import torch +from torchao.core.config import AOBaseConfig from torchao.quantization.granularity import ( Granularity, PerAxis, @@ -22,6 +23,9 @@ TorchAODType, ZeroPointDomain, ) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.unified import TwoStepQuantizer @@ -241,12 +245,26 @@ def __setattr__(self, name: str, value: Any): super().__setattr__(name, value) -def intx_quantization_aware_training( - activation_config: Optional[FakeQuantizeConfig] = None, - weight_config: Optional[FakeQuantizeConfig] = None, -) -> Callable: +@dataclass +class IntXQuantizationAwareTrainingConfig(AOBaseConfig): + activation_config: Optional[FakeQuantizeConfig] = None + weight_config: Optional[FakeQuantizeConfig] = None + + +# for BC +intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig + + +@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig) +def _intx_quantization_aware_training_transform( + module: torch.nn.Module, + config: IntXQuantizationAwareTrainingConfig, +) -> torch.nn.Module: """ - Return a function that applies fake quantization to a `torch.nn.Module`. + THIS IS NOT A PUBLIC API - any usage of this outside of torchao + can break at any time. + + Apply fake quantization to a `torch.nn.Module`. to be used with :func:`~torchao.quantization.quant_api.quantize_`. Example usage:: @@ -261,7 +279,7 @@ def intx_quantization_aware_training( ) quantize_( model, - intx_quantization_aware_training(activation_config, weight_config), + IntXQuantizationAwareTrainingConfig(activation_config, weight_config), ) Note: If the returned function is applied on a module that is not @@ -269,37 +287,32 @@ def intx_quantization_aware_training( `torch.nn.Embedding` with an activation config, then we will raise ValueError as these are not supported. """ - - def _insert_fake_quantize(mod: torch.nn.Module): - """ - Swap the given module with its corresponding fake quantized version. - """ - from .embedding import FakeQuantizedEmbedding - from .linear import FakeQuantizedLinear - - if isinstance(mod, torch.nn.Linear): - return FakeQuantizedLinear.from_linear( - mod, - activation_config, - weight_config, - ) - elif isinstance(mod, torch.nn.Embedding): - if activation_config is not None: - raise ValueError( - "Activation fake quantization is not supported for embedding" - ) - return FakeQuantizedEmbedding.from_embedding(mod, weight_config) - else: + from .embedding import FakeQuantizedEmbedding + from .linear import FakeQuantizedLinear + + mod = module + activation_config = config.activation_config + weight_config = config.weight_config + + if isinstance(mod, torch.nn.Linear): + return FakeQuantizedLinear.from_linear( + mod, + activation_config, + weight_config, + ) + elif isinstance(mod, torch.nn.Embedding): + if activation_config is not None: raise ValueError( - "Module of type '%s' does not have QAT support" % type(mod) + "Activation fake quantization is not supported for embedding" ) + return FakeQuantizedEmbedding.from_embedding(mod, weight_config) + else: + raise ValueError("Module of type '%s' does not have QAT support" % type(mod)) - return _insert_fake_quantize - -def from_intx_quantization_aware_training() -> Callable: +class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig): """ - Return a function that converts a model with fake quantized modules, + Object that knows how to convert a model with fake quantized modules, such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear` and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`, back to model with the original, corresponding modules without @@ -311,26 +324,35 @@ def from_intx_quantization_aware_training() -> Callable: from torchao.quantization import quantize_ quantize_( model_with_fake_quantized_linears, - from_intx_quantization_aware_training(), + FromIntXQuantizationAwareTrainingConfig(), ) """ - def _remove_fake_quantize(mod: torch.nn.Module): - """ - If the given module is a fake quantized module, return the original - corresponding version of the module without fake quantization. - """ - from .embedding import FakeQuantizedEmbedding - from .linear import FakeQuantizedLinear + pass + + +# for BC +from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig - if isinstance(mod, FakeQuantizedLinear): - return mod.to_linear() - elif isinstance(mod, FakeQuantizedEmbedding): - return mod.to_embedding() - else: - return mod - return _remove_fake_quantize +@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig) +def _from_intx_quantization_aware_training_transform( + mod: torch.nn.Module, + config: FromIntXQuantizationAwareTrainingConfig, +) -> torch.nn.Module: + """ + If the given module is a fake quantized module, return the original + corresponding version of the module without fake quantization. + """ + from .embedding import FakeQuantizedEmbedding + from .linear import FakeQuantizedLinear + + if isinstance(mod, FakeQuantizedLinear): + return mod.to_linear() + elif isinstance(mod, FakeQuantizedEmbedding): + return mod.to_embedding() + else: + return mod class ComposableQATQuantizer(TwoStepQuantizer): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9b7999449f..9f6599c177 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -18,13 +18,15 @@ import logging import types import warnings -from typing import Callable, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.utils.parametrize as parametrize import torchao +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( AffineQuantizedTensor, CutlassInt4PackedLayout, @@ -47,6 +49,10 @@ LinearActivationWeightObservedTensor, ) from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size +from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, + register_quantize_module_handler, +) from torchao.quantization.weight_tensor_linear_activation_quantization import ( to_weight_tensor_with_linear_activation_quantization_metadata, ) @@ -117,7 +123,6 @@ "Int8DynActInt4WeightGPTQQuantizer", ] -# update according to the support matrix LAYOUT_TO_ZERO_POINT_DOMAIN = { TensorCoreTiledLayout: [ZeroPointDomain.FLOAT], MarlinSparseLayout: [ZeroPointDomain.INT], @@ -228,6 +233,7 @@ def _replace_with_custom_fn_if_matches_filter( filter_fn, cur_fqn="", device=None, + extra_args: Optional[Tuple[Any, ...]] = (), ) -> None: """ Recursively replaces each child module in `model` with the result of `replacement_fn(child)` @@ -239,6 +245,7 @@ def _replace_with_custom_fn_if_matches_filter( filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace. cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "". device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None. + extra_args (Tuple[Any, ...], optional): optional extra args to pass to `replacement_fn`. Returns: None @@ -252,12 +259,18 @@ def _replace_with_custom_fn_if_matches_filter( if filter_fn(model, cur_fqn[:-1]): if device is not None: model.to(device=device) # move to device before quantization - model = replacement_fn(model) + model = replacement_fn(model, *extra_args) return model else: - for name, child in model.named_children(): + named_children_list = list(model.named_children()) + for name, child in named_children_list: new_child = _replace_with_custom_fn_if_matches_filter( - child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device + child, + replacement_fn, + filter_fn, + f"{cur_fqn}{name}.", + device, + extra_args, ) if new_child is not child: setattr(model, name, new_child) @@ -472,17 +485,17 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, - apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], + config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[torch.types.Device] = None, ): - """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace + """Convert the weight of linear modules in the model with `config`, model is modified inplace Args: model (torch.nn.Module): input model - apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor) - filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on + config (Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release. + filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on the weight of the module set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`. @@ -494,7 +507,7 @@ def quantize_( import torch.nn as nn from torchao import quantize_ - # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to + # quantize with some predefined `config` method that corresponds to # optimized execution paths or kernels (e.g. int4 tinygemm kernel) # also customizable with arguments # currently options are @@ -507,39 +520,36 @@ def quantize_( m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) quantize_(m, int4_weight_only(group_size=32)) - # 2. write your own new apply_tensor_subclass - # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor - # on weight - - from torchao.dtypes import to_affine_quantized_intx - - # weight only uint4 asymmetric groupwise quantization - groupsize = 32 - apply_weight_quant = lambda x: to_affine_quantized_intx( - x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6, - zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float") - - def apply_weight_quant_to_linear(linear): - linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False) - return linear - - # apply to modules under block0 submodule - def filter_fn(module: nn.Module, fqn: str) -> bool: - return isinstance(module, nn.Linear) - - m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) - quantize_(m, apply_weight_quant_to_linear, filter_fn) - """ if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - _replace_with_custom_fn_if_matches_filter( - model, - apply_tensor_subclass, - _is_linear if filter_fn is None else filter_fn, - device=device, - ) + if isinstance(config, AOBaseConfig): + handler = _QUANTIZE_CONFIG_HANDLER[type(config)] + # for each linear in the model, apply the transform if filtering passes + _replace_with_custom_fn_if_matches_filter( + model, + handler, + _is_linear if filter_fn is None else filter_fn, + device=device, + extra_args=(config,), + ) + + else: + # old behavior, keep to avoid breaking BC + warnings.warn( + """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead.""" + ) + + # make the variable name make sense + apply_tensor_subclass = config + + _replace_with_custom_fn_if_matches_filter( + model, + apply_tensor_subclass, + _is_linear if filter_fn is None else filter_fn, + device=device, + ) def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: @@ -741,14 +751,10 @@ def gemlite_uintx_weight_only( return _get_linear_subclass_inserter(apply_fn) -def int4_weight_only( - group_size=128, - layout=TensorCoreTiledLayout(inner_k_tiles=8), - use_hqq=False, - zero_point_domain=ZeroPointDomain.NONE, -): +@dataclass +class Int4WeightOnlyConfig(AOBaseConfig): """ - Applies uint4 weight-only asymmetric per-group quantization to linear layers, using + Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using "tensor_core_tiled" layout for speedup with tinygemm kernel Note: @@ -765,64 +771,90 @@ def int4_weight_only( size is more fine grained, choices are [256, 128, 64, 32] `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)` `use_hqq`: whether to use hqq or default quantization mode, default is False - `zero_point_domain`: data type of zeros points, choices are [None(then the value is determined by the layout), ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE] + `zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE] """ - def apply_int4_weight_only_quant(weight): - if weight.shape[-1] % group_size != 0: - logger.info( - f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" - ) - return weight + group_size: int = 128 + layout: Optional[TensorCoreTiledLayout] = TensorCoreTiledLayout(inner_k_tiles=8) + use_hqq: bool = False + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - eps = 1e-6 - preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] - zero_point_dtype = ( - weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16 + +# for BC +# TODO maybe change other callsites +int4_weight_only = Int4WeightOnlyConfig + + +@register_quantize_module_handler(Int4WeightOnlyConfig) +def _int4_weight_only_transform( + module: torch.nn.Module, config: Int4WeightOnlyConfig +) -> torch.nn.Module: + # TODO(future PR): perhaps move this logic to a different file, to keep the API + # file clean of implementation details + + # for now, make these local variables to allow the rest of the function + # to be a direct copy-paste + weight = module.weight + group_size = config.group_size + layout = config.layout + use_hqq = config.use_hqq + zero_point_domain = config.zero_point_domain + + if weight.shape[-1] % group_size != 0: + logger.info( + f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}" ) + return module + + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + eps = 1e-6 + preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)] + zero_point_dtype = ( + weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16 + ) - nonlocal zero_point_domain + # nonlocal zero_point_domain + assert ( + type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() + ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" + if zero_point_domain == ZeroPointDomain.NONE: + # the first value is the default one + zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] + else: assert ( - type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() - ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" - if zero_point_domain == ZeroPointDomain.NONE: - # the first value is the default one - zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] - else: - assert ( - zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] - ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" - - # Sparse Marlin only supports symmetric quantization. - # NOTE: If we start having lots of layouts that require different configurations, - # we should consider moving this logic somewhere else. - if isinstance(layout, MarlinSparseLayout): - mapping_type = MappingType.SYMMETRIC - assert ( - group_size == 128 or group_size == weight.shape[-1] - ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" + zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] + ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" - return to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, - _layout=layout, - use_hqq=use_hqq, - ) + # Sparse Marlin only supports symmetric quantization. + # NOTE: If we start having lots of layouts that require different configurations, + # we should consider moving this logic somewhere else. + if isinstance(layout, MarlinSparseLayout): + mapping_type = MappingType.SYMMETRIC + assert ( + group_size == 128 or group_size == weight.shape[-1] + ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" - return _get_linear_subclass_inserter(apply_int4_weight_only_quant) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + _layout=layout, + use_hqq=use_hqq, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def int8_weight_only(group_size=None): diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py new file mode 100644 index 0000000000..96fc808863 --- /dev/null +++ b/torchao/quantization/transform_module.py @@ -0,0 +1,46 @@ +import functools +from typing import Callable, Dict + +import torch + +from torchao.core.config import AOBaseConfig + +_QUANTIZE_CONFIG_HANDLER: Dict[ + AOBaseConfig, + Callable[[torch.nn.Module, AOBaseConfig], torch.nn.Module], +] = {} + + +def register_quantize_module_handler(config_type): + """ + A decorator to register a transform function to map from a workflow + configuration (child of `AOBaseConfig`) to a function that transforms + a `torch.nn.Module` according to the specified configuration. + + For example:: + + # user facing code + class WorkflowFooConfig(AOBaseConfig): ... + # configuration for workflow `Foo` is defined here + bar = 'baz' + + # non user facing code + @register_quantize_module_handler(WorkflowFooConfig) + def _transform( + mod: torch.nn.Module, + config: WorkflowFooConfig, + ) -> torch.nn.Module: + # the transform is implemented here, usually a tensor sublass + # weight swap or a module swap + ... + + # then, the user calls `quantize_` with a config, and `_transform` is called + # under the hood by `quantize_. + + """ + + @functools.wraps(config_type) + def decorator(func): + _QUANTIZE_CONFIG_HANDLER[config_type] = func + + return decorator From 2e51872663f9a55b24c9e6e322f94b3da4b9741c Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 13 Feb 2025 16:24:34 -0800 Subject: [PATCH 075/115] config migration: float8* (#1694) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 14 +- test/quantization/test_quant_api.py | 41 ++++- torchao/quantization/__init__.py | 6 + torchao/quantization/quant_api.py | 236 ++++++++++++++++----------- 4 files changed, 198 insertions(+), 99 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 53ca470b04..d26f1d8e04 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("apply_quant", get_quantization_functions(False, False)) def test_to_device(self, apply_quant): + def _apply(module, config_or_subclass_inserter): + if isinstance(config_or_subclass_inserter, AOBaseConfig): + quantize_(module, config_or_subclass_inserter) + else: + # TODO(#1690): delete this once config migration is done + module = config_or_subclass_inserter(module) + return module + linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(linear) + ql = _apply(linear, apply_quant) ql.to("cuda") linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(linear) + ql = _apply(linear, apply_quant) ql.to(device="cuda") linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16) - ql = apply_quant(linear) + ql = _apply(linear, apply_quant) ql.cuda() @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index acd9b50c5a..e0f6cb1ace 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -30,6 +30,9 @@ Quantizer, TwoStepQuantizer, _replace_with_custom_fn_if_matches_filter, + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + float8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -46,6 +49,7 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + is_sm_at_least_89, unwrap_tensor_subclass, ) @@ -784,28 +788,55 @@ def test_int4wo_cpu(self, dtype, x_dim): assert "_weight_int4pack_mm_for_cpu" in code[0] assert "aten.mm.default" not in code[0] + # TODO(#1690): move to new config names @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_int4_weight_only_numerics(self): + @common_utils.parametrize( + "config", + [ + int4_weight_only(), + float8_weight_only(), + float8_dynamic_activation_float8_weight(), + float8_static_activation_float8_weight(scale=torch.tensor([1.0])), + ], + ) + def test_workflow_e2e_numerics(self, config): """ Simple test of e2e int4_weight_only workflow, comparing numerics to a bfloat16 baseline. """ + if ( + isinstance( + config, + ( + float8_dynamic_activation_float8_weight, + float8_static_activation_float8_weight, + ), + ) + and not is_sm_at_least_89() + ): + return unittest.skip("requires CUDA capability 8.9 or greater") + + # scale has to be moved to cuda here because the parametrization init + # code happens before gating for cuda availability + if isinstance(config, float8_static_activation_float8_weight): + config.scale = config.scale.to("cuda") + # set up inputs x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 # is that expected? m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() - m_int4_wo = copy.deepcopy(m_ref) + m_q = copy.deepcopy(m_ref) # quantize - quantize_(m_int4_wo, int4_weight_only()) + quantize_(m_q, config) with torch.no_grad(): y_ref = m_ref(x) - y_int4_wo = m_int4_wo(x) + y_q = m_q(x) - sqnr = compute_error(y_ref, y_int4_wo) + sqnr = compute_error(y_ref, y_q) assert sqnr >= 20, f"SQNR {sqnr} is too low" diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 71e8de337a..ca9a4141fc 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -46,6 +46,9 @@ AffineQuantizedObserverBase, ) from .quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + Float8StaticActivationFloat8WeightConfig, + Float8WeightOnlyConfig, Int4WeightOnlyConfig, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, @@ -121,6 +124,9 @@ "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", "Int4WeightOnlyConfig", + "Float8WeightOnlyConfig", + "Float8DynamicActivationFloat8WeightConfig", + "Float8StaticActivationFloat8WeightConfig", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9f6599c177..6e5e043fb0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1030,30 +1030,43 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()) -def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): +@dataclass +class Float8WeightOnlyConfig(AOBaseConfig): """ - Applies float8 weight-only symmetric per-channel quantization to linear layers. + Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers. Args: weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. Note: The actual matmul will be computed in original precision of the weight tensor. - """ - from torchao.dtypes import to_affine_quantized_floatx - def apply_float8wo_quant(weight): - block_size = (1, weight.shape[1]) - return to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=None, - _layout=Float8Layout(mm_config=None), - ) + weight_dtype: torch.dtype = torch.float8_e4m3fn - return _get_linear_subclass_inserter(apply_float8wo_quant) + +# for BC +float8_weight_only = Float8WeightOnlyConfig + + +@register_quantize_module_handler(Float8WeightOnlyConfig) +def _float8_weight_only_transform( + module: torch.nn.Module, config: Float8WeightOnlyConfig +) -> torch.nn.Module: + from torchao.dtypes import to_affine_quantized_floatx + + weight = module.weight + block_size = (1, weight.shape[1]) + new_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=config.weight_dtype, + scale_dtype=None, + _layout=Float8Layout(mm_config=None), + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module _fp8_granularities = Union[PerTensor, PerRow] @@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool: return is_compatible -def float8_dynamic_activation_float8_weight( - activation_dtype: torch.dtype = torch.float8_e4m3fn, - weight_dtype: torch.dtype = torch.float8_e4m3fn, - granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] - ] = None, - mm_config: Optional[Float8MMConfig] = None, -): +@dataclass +class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): """ - Applies float8 dynamic symmetric quantization to both activations and weights of linear layers. + Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers. Args: activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. @@ -1192,56 +1199,76 @@ def float8_dynamic_activation_float8_weight( mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ + + activation_dtype: torch.dtype = torch.float8_e4m3fn + weight_dtype: torch.dtype = torch.float8_e4m3fn + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] = None + mm_config: Optional[Float8MMConfig] = None + + def __post_init__(self): + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) + + +# for bc +float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig + + +@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) +def _float8_dynamic_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig +): assert ( is_sm_at_least_89() or is_MI300() ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - if mm_config is None: - mm_config = Float8MMConfig(use_fast_accum=True) - activation_granularity, weight_granularity = _normalize_granularity(granularity) + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + granularity = config.granularity + mm_config = config.mm_config + weight = module.weight - def apply_float8_dynamic_activation_quant(weight: torch.Tensor): - if not _fp8_mm_compat(weight): - return weight - if isinstance(weight_granularity, PerRow): - assert ( - weight.dtype == torch.bfloat16 - ), "PerRow quantization only works for bfloat16 precision input weight" + activation_granularity, weight_granularity = _normalize_granularity(granularity) - block_size = get_block_size(weight.shape, weight_granularity) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) + if not _fp8_mm_compat(weight): + # TODO(future PR): this should really throw an exception instead of silently + # not doing what the user asked + return module + if isinstance(weight_granularity, PerRow): + assert ( + weight.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input weight" + + block_size = get_block_size(weight.shape, weight_granularity) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) - input_quant_func = _input_activation_quant_func_fp8 - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } - quantized_weight = to_linear_activation_quantized( - quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs - ) - return quantized_weight + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) - return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module -def float8_static_activation_float8_weight( - scale: torch.Tensor, - activation_dtype: torch.dtype = torch.float8_e4m3fn, - weight_dtype: torch.dtype = torch.float8_e4m3fn, - granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] - ] = None, - mm_config: Optional[Float8MMConfig] = None, -): +@dataclass +class Float8StaticActivationFloat8WeightConfig(AOBaseConfig): """ - Applies float8 static symmetric quantization to + Configuration for applying float8 static symmetric quantization to Args: scale (torch.Tensor): The scale tensor for activation quantization. @@ -1249,47 +1276,74 @@ def float8_static_activation_float8_weight( weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. """ + + scale: torch.Tensor + activation_dtype: torch.dtype = torch.float8_e4m3fn + weight_dtype: torch.dtype = torch.float8_e4m3fn + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] = None + mm_config: Optional[Float8MMConfig] = None + + def __post_init__(self): + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) + + +# for bc +float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig + + +@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) +def _float8_static_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig +): assert ( is_sm_at_least_89() or is_MI300() ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" - if mm_config is None: - mm_config = Float8MMConfig(use_fast_accum=True) + scale = config.scale + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + granularity = config.granularity + mm_config = config.mm_config + + weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) assert isinstance( activation_granularity, PerTensor ), "Static quantization only supports PerTensor granularity" - def apply_float8_static_activation_quant(weight: torch.Tensor): - if not _fp8_mm_compat(weight): - return weight - block_size = get_block_size(weight.shape, weight_granularity) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) + if not _fp8_mm_compat(weight): + # TODO(future PR): this should really throw an exception instead of silently + # not doing what the user asked + return module + block_size = get_block_size(weight.shape, weight_granularity) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) - input_quant_func = _input_activation_quant_func_fp8 - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } - - quantized_weight = ( - to_weight_tensor_with_linear_activation_quantization_metadata( - quantized_weight, - input_quant_func, - scale=scale, - zero_point=None, - quant_kwargs=input_quant_kwargs, - ) - ) - return quantized_weight + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } - return _get_linear_subclass_inserter(apply_float8_static_activation_quant) + quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata( + quantized_weight, + input_quant_func, + scale=scale, + zero_point=None, + quant_kwargs=input_quant_kwargs, + ) + + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): From 6fe41c282eeeb231a48225d0c751345571c5c07c Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 13 Feb 2025 16:25:39 -0800 Subject: [PATCH 076/115] config migration: int* (#1696) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/dtypes/test_affine_quantized.py | 1 + test/quantization/test_quant_api.py | 13 +- torchao/quantization/__init__.py | 8 + torchao/quantization/quant_api.py | 270 +++++++++++++++------------ 4 files changed, 173 insertions(+), 119 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index d26f1d8e04..616701f1e3 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -218,6 +218,7 @@ def test_flatten_unflatten(self, device, dtype): linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) if isinstance(apply_quant, AOBaseConfig): quantize_(linear, apply_quant) + ql = linear else: # TODO(#1690): delete this once config migration is done ql = apply_quant(linear) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index e0f6cb1ace..4cb0ee3579 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -33,6 +33,7 @@ float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, + int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, @@ -50,6 +51,7 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_sm_at_least_89, + is_sm_at_least_90, unwrap_tensor_subclass, ) @@ -798,6 +800,10 @@ def test_int4wo_cpu(self, dtype, x_dim): float8_weight_only(), float8_dynamic_activation_float8_weight(), float8_static_activation_float8_weight(scale=torch.tensor([1.0])), + int4_dynamic_activation_int4_weight(), + int8_dynamic_activation_int8_weight(), + int8_dynamic_activation_int4_weight(), + int8_weight_only(), ], ) def test_workflow_e2e_numerics(self, config): @@ -816,6 +822,11 @@ def test_workflow_e2e_numerics(self, config): and not is_sm_at_least_89() ): return unittest.skip("requires CUDA capability 8.9 or greater") + elif ( + isinstance(config, int4_dynamic_activation_int4_weight) + and is_sm_at_least_90() + ): + return unittest.skip("only supported on CUDA capability 8.9, not greater") # scale has to be moved to cuda here because the parametrization init # code happens before gating for cuda availability @@ -837,7 +848,7 @@ def test_workflow_e2e_numerics(self, config): y_q = m_q(x) sqnr = compute_error(y_ref, y_q) - assert sqnr >= 20, f"SQNR {sqnr} is too low" + assert sqnr >= 16.5, f"SQNR {sqnr} is too low" class TestMultiTensorFlow(TestCase): diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index ca9a4141fc..a1d8bda058 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -49,7 +49,11 @@ Float8DynamicActivationFloat8WeightConfig, Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, @@ -123,7 +127,11 @@ "fpx_weight_only", "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", + "Int4DynamicActivationInt4WeightConfig", + "Int8DynamicActivationInt4WeightConfig", + "Int8DynamicActivationInt8WeightConfig", "Int4WeightOnlyConfig", + "Int8WeightOnlyConfig", "Float8WeightOnlyConfig", "Float8DynamicActivationFloat8WeightConfig", "Float8StaticActivationFloat8WeightConfig", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6e5e043fb0..60ee0384c9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -43,6 +43,7 @@ to_affine_quantized_intx, to_marlinqqq_quantized_intx, ) +from torchao.dtypes.utils import Layout from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( @@ -590,18 +591,45 @@ def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor: ) -def apply_int8_dynamic_activation_int4_weight_quant( - weight, - group_size=32, - layout=PlainLayout(), - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.ASYMMETRIC, +@dataclass +class Int8DynamicActivationInt4WeightConfig(AOBaseConfig): + """Configuration for applying int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear + This is used to produce a model for executorch backend, but currently executorch did not + support lowering for the quantized model from this flow yet + + Args: + `group_size`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained + `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric + """ + + group_size: int = 32 + layout: Layout = PlainLayout() + mapping_type: MappingType = MappingType.SYMMETRIC + act_mapping_type: MappingType = MappingType.ASYMMETRIC + + +# for BC +int8_dynamic_activation_int4_weight = Int8DynamicActivationInt4WeightConfig + + +@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) +def _int8_dynamic_activation_int4_weight_transform( + module: torch.nn.Module, config: Int8DynamicActivationInt4WeightConfig ): - """This is defined here instead of local function to support serialization""" + group_size = config.group_size + layout = config.layout + mapping_type = config.mapping_type + act_mapping_type = config.act_mapping_type + + weight = module.weight + if group_size is None or group_size == -1: group_size = weight.shape[-1] if weight.shape[-1] % group_size != 0: - return weight + return module # weight settings block_size = (1, group_size) @@ -639,41 +667,39 @@ def apply_int8_dynamic_activation_int4_weight_quant( _layout=layout, ) weight = to_linear_activation_quantized(weight, input_quant_func) - return weight + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module -def int8_dynamic_activation_int4_weight( - group_size=32, - layout=PlainLayout(), - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.ASYMMETRIC, -): - """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear - This is used to produce a model for executorch backend, but currently executorch did not - support lowering for the quantized model from this flow yet +@dataclass +class Int4DynamicActivationInt4WeightConfig(AOBaseConfig): + """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear Args: - `group_size`: parameter for quantization, controls the granularity of quantization, smaller - size is more fine grained `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric """ - return _get_linear_subclass_inserter( - apply_int8_dynamic_activation_int4_weight_quant, - group_size=group_size, - layout=layout, - mapping_type=mapping_type, - act_mapping_type=act_mapping_type, - ) + layout: Layout = CutlassInt4PackedLayout() + mapping_type: MappingType = MappingType.SYMMETRIC + act_mapping_type: MappingType = MappingType.SYMMETRIC + + +# for bc +int4_dynamic_activation_int4_weight = Int4DynamicActivationInt4WeightConfig + + +@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) +def _int4_dynamic_activation_int4_weight_transform( + module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig +) -> torch.nn.Module: + weight = module.weight + layout = config.layout + mapping_type = config.mapping_type + act_mapping_type = config.act_mapping_type -def apply_int4_dynamic_activation_int4_weight_quant( - weight: torch.Tensor, - layout=CutlassInt4PackedLayout(), - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.SYMMETRIC, -): if not isinstance(layout, CutlassInt4PackedLayout): raise NotImplementedError( f"Only CutlassInt4PackedLayout layout is supported. Received {layout}." @@ -698,27 +724,9 @@ def apply_int4_dynamic_activation_int4_weight_quant( weight, _int4_symm_per_token_quant_cutlass, ) - return weight - - -def int4_dynamic_activation_int4_weight( - layout=CutlassInt4PackedLayout(), - mapping_type=MappingType.SYMMETRIC, - act_mapping_type=MappingType.SYMMETRIC, -): - """Applies int4 dynamic per token symmetric activation quantization and int4 per row weight symmetric quantization to linear - - Args: - `layout`: layout type for quantized weight tensor, only supports `MarlinQQQLayout()` and `CutlassInt4PackedLayout()` for now - `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric - `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric - """ - return _get_linear_subclass_inserter( - apply_int4_dynamic_activation_int4_weight_quant, - layout=layout, - mapping_type=mapping_type, - act_mapping_type=act_mapping_type, - ) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def gemlite_uintx_weight_only( @@ -857,29 +865,42 @@ def _int4_weight_only_transform( return module -def int8_weight_only(group_size=None): +@dataclass +class Int8WeightOnlyConfig(AOBaseConfig): """ - Applies int8 weight-only symmetric per-channel quantization to linear layers. + Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers. """ - def apply_int8wo_quant(weight, group_size=None): - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - if group_size is None: - group_size = weight.shape[1] - block_size = (1, group_size) - return to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - ) + group_size: Optional[int] = None + + +# for BC +int8_weight_only = Int8WeightOnlyConfig + + +@register_quantize_module_handler(Int8WeightOnlyConfig) +def _int8_weight_only_transform(module: torch.nn.Module, config: Int8WeightOnlyConfig): + group_size = config.group_size + weight = module.weight - return _get_linear_subclass_inserter(apply_int8wo_quant, group_size=group_size) + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + if group_size is None: + group_size = weight.shape[1] + block_size = (1, group_size) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor: @@ -958,63 +979,76 @@ def _int4_symm_per_token_quant_cutlass(x: torch.Tensor) -> torch.Tensor: ) -def int8_dynamic_activation_int8_weight( - layout=PlainLayout(), - act_mapping_type=MappingType.SYMMETRIC, - weight_only_decode=False, -): +@dataclass +class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): """ - Applies int8 dynamic symmetric per-token activation and int8 per-channel weight + Configuration for applying int8 dynamic symmetric per-token activation and int8 per-channel weight quantization to linear layers """ - def apply_int8_dynamic_activation_int8_weight_quant(weight): - in_features = weight.shape[1] - # int8 dynamic quantization only has benefit when in_feature > 16 - if in_features <= 16: - logger.info( - f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" - f" because `in_feature` is <= 16: {in_features}" - ) - return weight + layout: Optional[Layout] = PlainLayout() + act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC + weight_only_decode: bool = False - # weight settings - mapping_type = MappingType.SYMMETRIC - weight_zero_point_domain = ZeroPointDomain.NONE - def get_weight_block_size(x): - return (1, x.shape[1]) +# for BC +int8_dynamic_activation_int8_weight = Int8DynamicActivationInt8WeightConfig - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - if weight_only_decode: - input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode - else: - # input settings - if act_mapping_type == MappingType.SYMMETRIC: - input_quant_func = _int8_symm_per_token_reduced_range_quant - else: - input_quant_func = _int8_asymm_per_token_quant +@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) +def _int8_dynamic_activation_int8_weight_transform( + module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig +) -> torch.nn.Module: + layout = config.layout + act_mapping_type = config.act_mapping_type + weight_only_decode = config.weight_only_decode - block_size = get_weight_block_size(weight) - weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - _layout=layout, - zero_point_domain=weight_zero_point_domain, + weight = module.weight + + in_features = weight.shape[1] + # int8 dynamic quantization only has benefit when in_feature > 16 + if in_features <= 16: + logger.info( + f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}" + f" because `in_feature` is <= 16: {in_features}" ) - weight = to_linear_activation_quantized(weight, input_quant_func) - return weight + return module + + # weight settings + mapping_type = MappingType.SYMMETRIC + weight_zero_point_domain = ZeroPointDomain.NONE - return _get_linear_subclass_inserter( - apply_int8_dynamic_activation_int8_weight_quant + def get_weight_block_size(x): + return (1, x.shape[1]) + + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + if weight_only_decode: + input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode + else: + # input settings + if act_mapping_type == MappingType.SYMMETRIC: + input_quant_func = _int8_symm_per_token_reduced_range_quant + else: + input_quant_func = _int8_asymm_per_token_quant + + block_size = get_weight_block_size(weight) + weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=layout, + zero_point_domain=weight_zero_point_domain, ) + weight = to_linear_activation_quantized(weight, input_quant_func) + module.weight = torch.nn.Parameter(weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module def int8_dynamic_activation_int8_semi_sparse_weight(): From 413689db50d86a29d4250b51583cc410c3ee5196 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 13 Feb 2025 16:26:43 -0800 Subject: [PATCH 077/115] config migration: fpx, gemlite, uintx (#1697) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/dtypes/test_uintx.py | 6 +- test/hqq/test_hqq_affine.py | 8 +- test/quantization/test_quant_api.py | 23 +++- torchao/quantization/__init__.py | 6 + torchao/quantization/quant_api.py | 189 ++++++++++++++++++---------- 5 files changed, 156 insertions(+), 76 deletions(-) diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index da43253678..9bc983885e 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -150,7 +150,7 @@ def test_uintx_target_dtype(dtype): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - uintx_weight_only(dtype)(linear) + quantize_(linear, uintx_weight_only(dtype)) linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @@ -165,7 +165,7 @@ def test_uintx_target_dtype_compile(dtype): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") # make sure it runs - uintx_weight_only(dtype)(linear) + quantize_(linear, uintx_weight_only(dtype)) linear = torch.compile(linear) linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")) @@ -196,6 +196,6 @@ def test_uintx_model_size(dtype): ) bf16_size = get_model_size_in_bytes(linear) # make sure it runs - uintx_weight_only(dtype)(linear[0]) + quantize_(linear[0], uintx_weight_only(dtype)) quantized_size = get_model_size_in_bytes(linear) assert bf16_size * _dtype_to_ratio[dtype] == quantized_size diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 096c9d26ba..d18ff59f99 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -53,12 +53,10 @@ def _eval_hqq(dtype): dummy_linear.weight.data = W if dtype == torch.uint4: config = int4_weight_only(group_size=max(block_size), use_hqq=True) - quantize_(dummy_linear, config) - q_tensor_hqq = dummy_linear.weight else: - q_tensor_hqq = uintx_weight_only( - dtype, group_size=max(block_size), use_hqq=True - )(dummy_linear).weight + config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True) + quantize_(dummy_linear, config) + q_tensor_hqq = dummy_linear.weight quant_linear_layer = torch.nn.Linear( W.shape[1], W.shape[0], bias=False, device=W.device diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 4cb0ee3579..a53f47ac14 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -33,11 +33,14 @@ float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, + fpx_weight_only, + gemlite_uintx_weight_only, int4_dynamic_activation_int4_weight, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_weight_only, + uintx_weight_only, ) from torchao.quantization.quant_primitives import MappingType from torchao.quantization.subclass import ( @@ -55,6 +58,13 @@ unwrap_tensor_subclass, ) +try: + import gemlite # noqa: F401 + + has_gemlite = True +except ModuleNotFoundError: + has_gemlite = False + def dynamic_quant(model, example_inputs): m = torch.export.export(model, example_inputs, strict=True).module() @@ -804,6 +814,9 @@ def test_int4wo_cpu(self, dtype, x_dim): int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int4_weight(), int8_weight_only(), + fpx_weight_only(ebits=4, mbits=3), + gemlite_uintx_weight_only(), + uintx_weight_only(dtype=torch.uint4), ], ) def test_workflow_e2e_numerics(self, config): @@ -827,17 +840,23 @@ def test_workflow_e2e_numerics(self, config): and is_sm_at_least_90() ): return unittest.skip("only supported on CUDA capability 8.9, not greater") + elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite: + return unittest.skip("gemlite not available") # scale has to be moved to cuda here because the parametrization init # code happens before gating for cuda availability if isinstance(config, float8_static_activation_float8_weight): config.scale = config.scale.to("cuda") + dtype = torch.bfloat16 + if isinstance(config, gemlite_uintx_weight_only): + dtype = torch.float16 + # set up inputs - x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + x = torch.randn(128, 128, device="cuda", dtype=dtype) # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469 # is that expected? - m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16() + m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype) m_q = copy.deepcopy(m_ref) # quantize diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index a1d8bda058..5f15a6bbbe 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -49,11 +49,14 @@ Float8DynamicActivationFloat8WeightConfig, Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, + UIntXWeightOnlyConfig, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, float8_weight_only, @@ -135,6 +138,9 @@ "Float8WeightOnlyConfig", "Float8DynamicActivationFloat8WeightConfig", "Float8StaticActivationFloat8WeightConfig", + "UIntXWeightOnlyConfig", + "FPXWeightOnlyConfig", + "GemliteUIntXWeightOnlyConfig", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 60ee0384c9..e347529929 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform( return module -def gemlite_uintx_weight_only( - group_size: Optional[int] = 64, - bit_width: int = 4, - packing_bitwidth: int = 32, - contiguous: Optional[bool] = None, -): +@dataclass +class GemliteUIntXWeightOnlyConfig(AOBaseConfig): """ applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format. This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric. @@ -747,16 +743,39 @@ def gemlite_uintx_weight_only( `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice. """ + group_size: Optional[int] = 64 + bit_width: int = 4 + packing_bitwidth: int = 32 + contiguous: Optional[bool] = None + + +# for BC +gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig + + +@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) +def _gemlite_uintx_weight_only_transform( + module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig +): + group_size = config.group_size + bit_width = config.bit_width + packing_bitwidth = config.packing_bitwidth + contiguous = config.contiguous + + weight = module.weight + from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs use_hqq = True if bit_width == 4 else False - apply_fn = lambda weight: to_affine_quantized_intx( + new_weight = to_affine_quantized_intx( weight, **get_gemlite_aqt_kwargs( weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq ), ) - return _get_linear_subclass_inserter(apply_fn) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass @@ -1380,9 +1399,10 @@ def _float8_static_activation_float8_weight_transform( return module -def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): +@dataclass +class UIntXWeightOnlyConfig(AOBaseConfig): """ - Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where + Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where x is the number of bits specified by `dtype` Args: @@ -1392,6 +1412,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): `pack_dim`: the dimension we use for packing, defaults to -1 `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight """ + + dtype: torch.dtype + group_size: int = 64 + pack_dim: int = -1 + use_hqq: bool = False + + +# for BC +uintx_weight_only = UIntXWeightOnlyConfig + + +@register_quantize_module_handler(UIntXWeightOnlyConfig) +def _uintx_weight_only_transform( + module: torch.nn.Module, config: UIntXWeightOnlyConfig +): + dtype = config.dtype + group_size = config.group_size + pack_dim = config.pack_dim + use_hqq = config.use_hqq + + weight = module.weight + from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS SUPPORTED_DTYPES = { @@ -1406,49 +1448,50 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): } assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}" - def apply_uintx_weight_only_quant(weight, dtype): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - - if use_hqq: - if dtype == torch.uint4: - logger.warn( - "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance" - ) - quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] - dtype = torch.uint8 - eps = None - zero_point_dtype = None - zero_point_domain = ZeroPointDomain.FLOAT - preserve_zero = False - _layout = PlainLayout() - else: - quant_min, quant_max = None, None - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int32 - zero_point_domain = ZeroPointDomain.INT - preserve_zero = True - _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) - return to_affine_quantized_intx( - weight, - mapping_type, - block_size, - dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - zero_point_dtype=zero_point_dtype, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - _layout=_layout, - use_hqq=use_hqq, - ) + if use_hqq: + if dtype == torch.uint4: + logger.warn( + "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance" + ) + quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] + dtype = torch.uint8 + eps = None + zero_point_dtype = None + zero_point_domain = ZeroPointDomain.FLOAT + preserve_zero = False + _layout = PlainLayout() + else: + quant_min, quant_max = None, None + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int32 + zero_point_domain = ZeroPointDomain.INT + preserve_zero = True + _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) - return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + zero_point_dtype=zero_point_dtype, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, + _layout=_layout, + use_hqq=use_hqq, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module -def fpx_weight_only(ebits: int, mbits: int): +@dataclass +class FPXWeightOnlyConfig(AOBaseConfig): """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits e.g. fp6_e3_m2, fp6_e2_m3, ... The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 @@ -1459,26 +1502,40 @@ def fpx_weight_only(ebits: int, mbits: int): in the future """ - def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: - from torchao.dtypes import to_affine_quantized_fpx - from torchao.dtypes.floatx import FloatxTensorCoreLayout + ebits: int + mbits: int - assert ( - weight.dim() == 2 - ), f"floatx only works for 2-d Tensor, got: {weight.dim()}" - out_dim, in_dim = weight.shape - if (in_dim % 64 != 0) or (out_dim % 256 != 0): - logger.info( - f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because " - f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} " - "expected in_dim % 64 == 0 and out_dim % 256 == 0" - ) - return weight - _layout = FloatxTensorCoreLayout(ebits, mbits) - return to_affine_quantized_fpx(weight, _layout) +# for BC +fpx_weight_only = FPXWeightOnlyConfig + + +@register_quantize_module_handler(FPXWeightOnlyConfig) +def _fpx_weight_only_transform( + module: torch.nn.Module, config: FPXWeightOnlyConfig +) -> torch.nn.Module: + ebits = config.ebits + mbits = config.mbits + weight = module.weight + + from torchao.dtypes import to_affine_quantized_fpx + from torchao.dtypes.floatx import FloatxTensorCoreLayout - return _get_linear_subclass_inserter(apply_quant_llm) + assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" + out_dim, in_dim = weight.shape + if (in_dim % 64 != 0) or (out_dim % 256 != 0): + logger.info( + f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because " + f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} " + "expected in_dim % 64 == 0 and out_dim % 256 == 0" + ) + return module + + _layout = FloatxTensorCoreLayout(ebits, mbits) + new_weight = to_affine_quantized_fpx(weight, _layout) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module if TORCH_VERSION_AT_LEAST_2_5: From 17b9ce3586b46a8e4eb7561d0a17b3fe7a07f6f2 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 13 Feb 2025 16:27:44 -0800 Subject: [PATCH 078/115] unbreak float8 static quant tutorial (#1709) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- torchao/dtypes/floatx/float8_layout.py | 1 + tutorials/calibration_flow/static_quant.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 5a7e1924b3..656ebb61ae 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -253,6 +253,7 @@ def _linear_fp8_act_fp8_weight_impl( ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" scaled_mm_config = weight_tensor._layout.mm_config + assert scaled_mm_config is not None out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) # Weight tensor preprocessing diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index 4b7dfe405f..fd24a71189 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -163,12 +163,13 @@ def __init__( weight, weight_scale, weight_zero_point, block_size, self.target_dtype ) elif self.target_dtype == torch.float8_e4m3fn: + mm_config = Float8MMConfig(use_fast_accum=True) self.qweight = to_affine_quantized_floatx_static( weight, weight_scale, block_size, target_dtype, - Float8Layout(mm_config=None), + Float8Layout(mm_config=mm_config), ) else: raise ValueError(f"Unsupported target dtype {self.target_dtype}") From 3fa8e4442c38522f9339b9cbb64fb244a9a1b153 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 13 Feb 2025 16:28:40 -0800 Subject: [PATCH 079/115] migrate static quant tutorials to direct configuration (#1710) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- tutorials/calibration_flow/awq_like.py | 114 ++++++++++-------- tutorials/calibration_flow/gptq_like.py | 66 ++++++----- tutorials/calibration_flow/static_quant.py | 131 ++++++++++++--------- 3 files changed, 178 insertions(+), 133 deletions(-) diff --git a/tutorials/calibration_flow/awq_like.py b/tutorials/calibration_flow/awq_like.py index 5742b9b328..c047b8531e 100644 --- a/tutorials/calibration_flow/awq_like.py +++ b/tutorials/calibration_flow/awq_like.py @@ -8,11 +8,13 @@ """ import copy +from dataclasses import dataclass import torch import torch.nn.functional as F from torch import Tensor +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( Float8Layout, to_affine_quantized_floatx_static, @@ -33,6 +35,9 @@ from torchao.quantization.quant_primitives import ( MappingType, ) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error @@ -83,61 +88,72 @@ def replacement_fn(m): _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) +@dataclass +class ApplyAWQConfig(AOBaseConfig): + target_dtype: torch.dtype + + # converting observed linear module to linear module with quantzied weights (and quantized activations) # with tensor subclasses -def apply_awq(target_dtype: torch.dtype): - # target_dtype = torch.uint8 - def _apply_awq_to_linear(observed_linear): - # weight quantization - weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() - - def weight_quant_func(weight): - block_size = (1, weight.shape[1]) - if target_dtype == torch.uint8: - return to_affine_quantized_intx_static( - weight, weight_scale, weight_zero_point, block_size, target_dtype - ) - elif target_dtype == torch.float8_e4m3fn: - return to_affine_quantized_floatx_static( - weight, - weight_scale, - block_size, - target_dtype, - Float8Layout(mm_config=None), - ) - else: - raise ValueError(f"Unsupported target dtype {target_dtype}") - - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - False, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.weight = observed_linear.weight - linear.bias = observed_linear.bias - # activation quantization - # pretend this to be the equalization scale, in reality the `act_obs` should - # be an observer that can caluclate equalization scale - equalization_scale, _ = observed_linear.act_obs.calculate_qparams() - equalization_scale = torch.ones_like(equalization_scale) - linear.weight = torch.nn.Parameter( - weight_quant_func(linear.weight * equalization_scale), requires_grad=False - ) +@register_quantize_module_handler(ApplyAWQConfig) +def _apply_awq_transform( + module: torch.nn.Module, + config: ApplyAWQConfig, +): + target_dtype = config.target_dtype + observed_linear = module - linear.weight = torch.nn.Parameter( - to_weight_tensor_with_linear_activation_scale_metadata( - linear.weight, equalization_scale - ), - requires_grad=False, - ) + # target_dtype = torch.uint8 + # weight quantization + weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() + + def weight_quant_func(weight): + block_size = (1, weight.shape[1]) + if target_dtype == torch.uint8: + return to_affine_quantized_intx_static( + weight, weight_scale, weight_zero_point, block_size, target_dtype + ) + elif target_dtype == torch.float8_e4m3fn: + return to_affine_quantized_floatx_static( + weight, + weight_scale, + block_size, + target_dtype, + Float8Layout(mm_config=None), + ) + else: + raise ValueError(f"Unsupported target dtype {target_dtype}") + + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + False, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) + linear.weight = observed_linear.weight + linear.bias = observed_linear.bias + + # activation quantization + # pretend this to be the equalization scale, in reality the `act_obs` should + # be an observer that can caluclate equalization scale + equalization_scale, _ = observed_linear.act_obs.calculate_qparams() + equalization_scale = torch.ones_like(equalization_scale) - return linear + linear.weight = torch.nn.Parameter( + weight_quant_func(linear.weight * equalization_scale), requires_grad=False + ) + + linear.weight = torch.nn.Parameter( + to_weight_tensor_with_linear_activation_scale_metadata( + linear.weight, equalization_scale + ), + requires_grad=False, + ) - return _apply_awq_to_linear + return linear ######## Test ########## @@ -201,7 +217,7 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType): # quantized linear represented as an nn.Linear with modified tensor subclass weights # for both activation and weight quantization - quantize_(m, apply_awq(target_dtype), is_observed_linear) + quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear) print("quantized model (applying tensor subclass to weight):", m) after_quant = m(*example_inputs) assert compute_error(before_quant, after_quant) > 25 diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py index 93c7e3c4ab..e4f28faf6f 100644 --- a/tutorials/calibration_flow/gptq_like.py +++ b/tutorials/calibration_flow/gptq_like.py @@ -33,6 +33,7 @@ import torch from torch.utils._pytree import tree_flatten, tree_unflatten +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( to_affine_quantized_intx, to_affine_quantized_intx_static, @@ -47,6 +48,9 @@ to_linear_activation_quantized, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error torch.manual_seed(0) @@ -252,36 +256,42 @@ def _register_forward_pre_hook(module: torch.nn.Module): ) -# using a function to align with the API in quant_api -def apply_activation_static_weight_quant(): - def _apply_activation_static_weight_quant(observed_linear): - target_dtype = torch.uint8 - - # we can quantize the weight here as well +class ApplyActivationStaticWeightQuantConfig(AOBaseConfig): + pass - # activation quantization - act_scale, act_zero_point = ( - observed_linear.input_scale, - observed_linear.input_zp, - ) - input_quant_func = lambda x: to_affine_quantized_intx_static( - x, act_scale, act_zero_point, x.shape, target_dtype - ) - # for demo purpose only, we quantize the weight here - weight = observed_linear.weight - weight = to_affine_quantized_intx( - weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8 - ) - observed_linear.weight = torch.nn.Parameter( - to_linear_activation_quantized(weight, input_quant_func), - requires_grad=False, - ) - del observed_linear.input_scale - del observed_linear.input_zp - return observed_linear +# using a function to align with the API in quant_api +@register_quantize_module_handler(ApplyActivationStaticWeightQuantConfig) +def _apply_activation_static_weight_quant_transform( + module: torch.nn.Module, + config: ApplyActivationStaticWeightQuantConfig, +): + observed_linear = module + target_dtype = torch.uint8 + + # we can quantize the weight here as well + + # activation quantization + act_scale, act_zero_point = ( + observed_linear.input_scale, + observed_linear.input_zp, + ) + input_quant_func = lambda x: to_affine_quantized_intx_static( + x, act_scale, act_zero_point, x.shape, target_dtype + ) + # for demo purpose only, we quantize the weight here + weight = observed_linear.weight + weight = to_affine_quantized_intx( + weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8 + ) + observed_linear.weight = torch.nn.Parameter( + to_linear_activation_quantized(weight, input_quant_func), + requires_grad=False, + ) - return _apply_activation_static_weight_quant + del observed_linear.input_scale + del observed_linear.input_zp + return observed_linear example_inputs = (torch.randn(32, 64),) @@ -298,7 +308,7 @@ def _apply_activation_static_weight_quant(observed_linear): # just quantizing activation since we only observed quantization, this could be extended to support # quantizing weight as well -quantize_(m, apply_activation_static_weight_quant(), _is_linear) +quantize_(m, ApplyActivationStaticWeightQuantConfig(), _is_linear) for l in m.modules(): if isinstance(l, torch.nn.Linear): assert isinstance(l.weight, LinearActivationQuantizedTensor) diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index fd24a71189..1ebce411d3 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -3,11 +3,13 @@ """ import copy +from dataclasses import dataclass import torch import torch.nn.functional as F from torch import Tensor +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( Float8Layout, to_affine_quantized_floatx_static, @@ -26,6 +28,9 @@ from torchao.quantization.quant_primitives import ( MappingType, ) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error from torchao.utils import is_sm_at_least_90 @@ -77,66 +82,74 @@ def replacement_fn(m): _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) -# converting observed linear module to linear module with quantzied weights (and quantized activations) -# with tensor subclasses -def apply_static_quant(target_dtype: torch.dtype): - # target_dtype = torch.uint8 - def _apply_static_quant_to_linear(observed_linear): - # weight quantization - weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() - - def weight_quant_func(weight): - block_size = (1, weight.shape[1]) - if target_dtype == torch.uint8: - return to_affine_quantized_intx_static( - weight, weight_scale, weight_zero_point, block_size, target_dtype - ) - elif target_dtype == torch.float8_e4m3fn: - mm_config = Float8MMConfig(use_fast_accum=True) - return to_affine_quantized_floatx_static( - weight, - weight_scale, - block_size, - target_dtype, - Float8Layout(mm_config=mm_config), - ) - else: - raise ValueError(f"Unsupported target dtype {target_dtype}") - - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - False, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.weight = observed_linear.weight - linear.bias = observed_linear.bias +@dataclass +class ApplyStaticQuantConfig(AOBaseConfig): + target_dtype: torch.dtype - linear.weight = torch.nn.Parameter( - weight_quant_func(linear.weight), requires_grad=False - ) - # activation quantization - act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams() +# converting observed linear module to linear module with quantzied weights (and quantized activations) +# with tensor subclasses +@register_quantize_module_handler(ApplyStaticQuantConfig) +def _apply_static_quant_transform( + module: torch.nn.Module, + config: ApplyStaticQuantConfig, +): + target_dtype = config.target_dtype + observed_linear = module + + # weight quantization + weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() + + def weight_quant_func(weight): + block_size = (1, weight.shape[1]) if target_dtype == torch.uint8: - input_quant_func = lambda x: to_affine_quantized_intx_static( - x, act_scale, act_zero_point, x.shape, target_dtype + return to_affine_quantized_intx_static( + weight, weight_scale, weight_zero_point, block_size, target_dtype ) elif target_dtype == torch.float8_e4m3fn: - input_quant_func = lambda x: to_affine_quantized_floatx_static( - x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None) + mm_config = Float8MMConfig(use_fast_accum=True) + return to_affine_quantized_floatx_static( + weight, + weight_scale, + block_size, + target_dtype, + Float8Layout(mm_config=mm_config), ) else: raise ValueError(f"Unsupported target dtype {target_dtype}") - linear.weight = torch.nn.Parameter( - to_linear_activation_quantized(linear.weight, input_quant_func), - requires_grad=False, - ) - return linear + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + False, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) + linear.weight = observed_linear.weight + linear.bias = observed_linear.bias - return _apply_static_quant_to_linear + linear.weight = torch.nn.Parameter( + weight_quant_func(linear.weight), requires_grad=False + ) + + # activation quantization + act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams() + if target_dtype == torch.uint8: + input_quant_func = lambda x: to_affine_quantized_intx_static( + x, act_scale, act_zero_point, x.shape, target_dtype + ) + elif target_dtype == torch.float8_e4m3fn: + input_quant_func = lambda x: to_affine_quantized_floatx_static( + x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None) + ) + else: + raise ValueError(f"Unsupported target dtype {target_dtype}") + linear.weight = torch.nn.Parameter( + to_linear_activation_quantized(linear.weight, input_quant_func), + requires_grad=False, + ) + + return linear # alternative for converting observed linear module to quantized linear module @@ -210,11 +223,17 @@ def from_observed(cls, observed_linear, target_dtype): return quantized_linear -def apply_static_quant2(target_dtype: torch.dtype): - def _apply_static_quant2(observed_linear): - return QuantizedLinear.from_observed(observed_linear, target_dtype) +@dataclass +class ApplyStaticQuantConfig2(AOBaseConfig): + target_dtype: torch.dtype + - return _apply_static_quant2 +@register_quantize_module_handler(ApplyStaticQuantConfig2) +def apply_static_quant( + module: torch.nn.Module, + config: ApplyStaticQuantConfig2, +): + return QuantizedLinear.from_observed(module, config.target_dtype) class ToyLinearModel(torch.nn.Module): @@ -281,14 +300,14 @@ def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType): # quantized linear represented as an nn.Linear with modified tensor subclass weights # for both activation and weight quantization - quantize_(m, apply_static_quant(target_dtype), is_observed_linear) + quantize_(m, ApplyStaticQuantConfig(target_dtype), is_observed_linear) print("quantized model (applying tensor subclass to weight):", m) after_quant = m(*example_inputs) assert compute_error(before_quant, after_quant) > 25 print("test passed") # quantized linear as a standalone module - quantize_(m2, apply_static_quant2(target_dtype), is_observed_linear) + quantize_(m2, ApplyStaticQuantConfig2(target_dtype), is_observed_linear) print("quantized model (quantized module):", m2) after_quant = m2(*example_inputs) assert compute_error(before_quant, after_quant) > 25 From 12e830b49fb997de2ecd4a986f12df76d6442e64 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Thu, 13 Feb 2025 16:29:40 -0800 Subject: [PATCH 080/115] update torchao READMEs with new configuration APIs (#1711) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- README.md | 26 +++++++++--------- torchao/quantization/README.md | 44 +++++++++++++++--------------- torchao/quantization/qat/README.md | 18 ++++++------ 3 files changed, 44 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 0da273f91c..e3cdc60aba 100644 --- a/README.md +++ b/README.md @@ -29,16 +29,16 @@ For inference, we have the option of ```python from torchao.quantization.quant_api import ( quantize_, - int8_dynamic_activation_int8_weight, - int4_weight_only, - int8_weight_only + Int8DynamicActivationInt8WeightConfig, + Int4WeightOnlyConfig, + Int8WeightOnlyConfig ) -quantize_(m, int4_weight_only()) +quantize_(m, Int4WeightOnlyConfig()) ``` -For gpt-fast `int4_weight_only()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline. +For gpt-fast `Int4WeightOnlyConfig()` is the best option at bs=1 as it **2x the tok/s and reduces the VRAM requirements by about 65%** over a torch.compiled baseline. -If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, int8_weight_only(), device="cuda")` which will send and quantize each layer individually to your GPU. +If you don't have enough VRAM to quantize your entire model on GPU and you find CPU quantization to be too slow then you can use the device argument like so `quantize_(model, Int8WeightOnlyConfig(), device="cuda")` which will send and quantize each layer individually to your GPU. If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer. @@ -63,12 +63,12 @@ Post-training quantization can result in a fast and compact model, but may also ```python from torchao.quantization import ( quantize_, - int8_dynamic_activation_int4_weight, + Int8DynamicActivationInt4WeightConfig, ) from torchao.quantization.qat import ( FakeQuantizeConfig, - from_intx_quantization_aware_training, - intx_quantization_aware_training, + FromIntXQuantizationAwareTrainingConfig, + IntXQuantizationAwareTrainingConfig, ) # Insert fake quantization @@ -76,14 +76,14 @@ activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=Fal weight_config = FakeQuantizeConfig(torch.int4, group_size=32) quantize_( my_model, - intx_quantization_aware_training(activation_config, weight_config), + IntXQuantizationAwareTrainingConfig(activation_config, weight_config), ) # Run training... (not shown) # Convert fake quantization to actual quantized operations -quantize_(my_model, from_intx_quantization_aware_training()) -quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32)) +quantize_(my_model, FromIntXQuantizationAwareTrainingConfig()) +quantize_(my_model, Int8DynamicActivationInt4WeightConfig(group_size=32)) ``` ### Float8 @@ -139,7 +139,7 @@ The best example we have combining the composability of lower bit dtype with com We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow -1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))` +1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))` 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index ace4d8c14c..655a942718 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -82,7 +82,7 @@ model(input) When used as in the example above, when the `autoquant` api is called alongside torch.compile, autoquant sets up the model so that when its run on the next input, the autoquantization and torch.compile processes leave you with a heavily optimized model. -When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `int4_weight_only()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow. +When `model(input)` is called, (under the hood) the tool does a preliminary run with the input where each linear layer keeps track of the different shapes and types of activations that it sees. Once the preliminary run is complete, the next step is to check each linear layer and benchmark the tracked shapes for different types of quantization techniques in order to pick the fastest one, attempting to take into account fusions where possible. Finally once the best class is found for each layer, the next step is to apply the necessary quantization technique to each layer, before finally allowing the normal `torch.compile` process to occur on the now quantized model. By default the api only uses int8 techniques, i.e. it chooses between no quantization, int8 dynamic quantization and int8 weight only quantization for each layer, though there is also an option add int4 quantization which can be used for maximum performance or to avoid perf regressions from `Int4WeightOnlyConfig()` since for certain (compute bound) regimes, int4 weight only quantization can be very slow. Sometimes it is desirable to reuse a quantization plan that `autoquant` came up with. `torchao.quantization.AUTOQUANT_CACHE` is a dictionary holding autoquant's benchmark results. We can save it and restore it later, which will cause `autoquant` to choose the same quantization methods. @@ -109,13 +109,13 @@ be applied individually. While there are a large variety of quantization apis, t ```python # for torch 2.4+ -from torchao.quantization import quantize_, int4_weight_only +from torchao.quantization import quantize_, Int4WeightOnlyConfig group_size = 32 # you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through -# use_hqq flag for `int4_weight_only` quantization +# use_hqq flag for `Int4WeightOnlyConfig` quantization use_hqq = False -quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq)) +quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors @@ -128,8 +128,8 @@ Note: The quantization error incurred by applying int4 quantization to your mode ```python # for torch 2.4+ -from torchao.quantization import quantize_, int8_weight_only -quantize_(model, int8_weight_only()) +from torchao.quantization import quantize_, Int8WeightOnlyConfig +quantize_(model, Int8WeightOnlyConfig()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors @@ -140,8 +140,8 @@ change_linear_weights_to_int8_woqtensors(model) ```python # for torch 2.4+ -from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight -quantize_(model, int8_dynamic_activation_int8_weight()) +from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig +quantize_(model, Int8DynamicActivationInt8WeightConfig()) # for torch 2.2.2 and 2.3 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors @@ -152,8 +152,8 @@ change_linear_weights_to_int8_dqtensors(model) ```python # for torch 2.5+ -from torchao.quantization import quantize_, float8_weight_only -quantize_(model, float8_weight_only()) +from torchao.quantization import quantize_, Float8WeightOnlyConfig +quantize_(model, Float8WeightOnlyConfig()) ``` Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. @@ -162,8 +162,8 @@ Supports all dtypes for original weight and activation. This API is only tested ```python # for torch 2.4+ -from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor -quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor())) +from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor +quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())) ``` Supports all dtypes for original weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. @@ -172,8 +172,8 @@ Supports all dtypes for original weight and activation. This API is only tested ```python # for torch 2.5+ -from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight -quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow())) +from torchao.quantization import quantize_, PerRow, Float8DynamicActivationFloat8WeightConfig +quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) ``` Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required. @@ -182,14 +182,14 @@ Per-row scaling is only supported for bfloat16 weight and activation. This API i ```python # for torch 2.4+ -from torchao.quantization import quantize_, fpx_weight_only -quantize_(model, fpx_weight_only(3, 2)) +from torchao.quantization import quantize_, FPXWeightOnlyConfig +quantize_(model, FPXWeightOnlyConfig(3, 2)) ``` You can find more information [here](../dtypes/floatx/README.md). It should be noted where most other TorchAO apis and benchmarks have focused on applying techniques on top of a bf16 model, performance, fp6 works primarily with the fp16 dtype. ## Affine Quantization Details -Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_preicsion_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization. +Affine quantization refers to the type of quantization that maps from high precision floating point numbers to quantized numbers (low precision integer or floating point dtypes) with an affine transformation, i.e.: `quantized_val = high_precision_float_val / scale + zero_point` where `scale` and `zero_point` are quantization parameters for some granularity and based on some data (also some dtypes may not require a `zero_point`). Each of the techniques in the above section qualify as Affine Quantization. ### Quantization Primitives We used to have different quantize and dequantize operators for quantization with different granularities. But in the end these can all be expressed with a `block_size` argument with different settings, so we unified existing quant primitives to `choose_qparams_affine`, `quantize_affine` and `dequantize_affine` that can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization, this can be used to implement the unified quantized tensor subclass. @@ -200,7 +200,7 @@ Note: these primitive ops supports two "types" of quantization, distinguished by We also have a unified quantized tensor subclass that implements how to get a quantized tensor from floating point tensor and what does it mean to call linear ops on an instance of the tensor, e.g. `F.linear` and `aten.addmm`, with this we could dispatch to different operators (e.g. `int4mm` op) based on device (cpu, cuda) and quantization settings (`int4`, `int8`) and also packing formats (e.g. format optimized for cpu int4 mm kernel) #### Layouts -We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels. +We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for workflows backing `Int8WeightOnlyConfig` and `Int8DynamicActivationInt8WeightConfig` and also as a default layout. `tensor_core_tiled` layout is used for workflows backing `Int4WeightOnlyConfig` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels. ### Zero Point Domains ```ZeroPointDomain``` is used to control the data types of zero points. ```ZeroPointDomain.None``` means zero_point is None, ```ZeroPointDomain.FLOAT``` means zero_point is in the floating point domain and ```ZeroPointDomain.INT``` means integer domain. For detailed implementation of different zero point data types, refer to [the reference implementation](../../test/quantization/test_quant_primitives.py). @@ -223,7 +223,7 @@ from torchao.dtypes import to_affine_quantized_intx import copy from torchao.quantization.quant_api import ( quantize_, - int4_weight_only, + Int4WeightOnlyConfig, ) class ToyLinearModel(torch.nn.Module): @@ -249,9 +249,9 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune') # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao) group_size = 32 # only works for torch 2.4+ -quantize_(m, int4_weight_only(group_size=group_size)) +quantize_(m, Int4WeightOnlyConfig(group_size=group_size)) ## If different zero_point_domain needed -# quantize_(m, int4_weight_only(group_size=group_size), zero_point_domain=ZeroPointDomain.FLOAT) +# quantize_(m, Int4WeightOnlyConfig(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT)) # temporary workaround for tensor subclass + torch.compile # NOTE: this is only need for torch version < 2.5+ @@ -360,7 +360,7 @@ We're trying to develop kernels for low bit quantization for intx quantization f | | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 | | | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 | -You try can out these apis with the `quantize_` api as above alongside the constructor `uintx_weight_only` an example can be found in in `torchao/_models/llama/generate.py`. +You try can out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in in `torchao/_models/llama/generate.py`. ### int8_dynamic_activation_intx_weight Quantization We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computers with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 perf cores, and 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used. diff --git a/torchao/quantization/qat/README.md b/torchao/quantization/qat/README.md index 813b628af7..0f024dbf61 100644 --- a/torchao/quantization/qat/README.md +++ b/torchao/quantization/qat/README.md @@ -71,9 +71,9 @@ def train_loop(m: torch.nn.Module): The recommended way to run QAT in torchao is through the `quantize_` API: 1. **Prepare:** specify how weights and/or activations are to be quantized through -[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`intx_quantization_aware_training`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242) +[`FakeQuantizeConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/qat/api.py#L29) and passing these to [`IntXQuantizationAwareTrainingConfig`](https://github.com/pytorch/ao/blob/cedadc741954f47a9e9efac2aa584701f125bc73/torchao/quantization/qat/api.py#L242) 2. **Convert:** quantize the model using the standard post-training quantization (PTQ) -functions such as [`int8_dynamic_activation_int4_weight`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606) +functions such as [`Int8DynamicActivationInt4WeightConfig`](https://github.com/pytorch/ao/blob/v0.7.0/torchao/quantization/quant_api.py#L606) For example: @@ -81,12 +81,12 @@ For example: ```python from torchao.quantization import ( quantize_, - int8_dynamic_activation_int4_weight, + Int8DynamicActivationInt4WeightConfig, ) from torchao.quantization.qat import ( FakeQuantizeConfig, - from_intx_quantization_aware_training, - intx_quantization_aware_training, + FromIntXQuantizationAwareTrainingConfig, + IntXQuantizationAwareTrainingConfig, ) model = get_model() @@ -96,7 +96,7 @@ activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=Fal weight_config = FakeQuantizeConfig(torch.int4, group_size=32) quantize_( model, - intx_quantization_aware_training(activation_config, weight_config), + IntXQuantizationAwareTrainingConfig(activation_config, weight_config), ) # train @@ -105,8 +105,8 @@ train_loop(model) # convert: transform fake quantization ops into actual quantized ops # swap `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts # quantized activation and weight tensor subclasses -quantize_(model, from_intx_quantization_aware_training()) -quantize_(model, int8_dynamic_activation_int4_weight(group_size=32)) +quantize_(model, FromIntXQuantizationAwareTrainingConfig()) +quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32)) # inference or generate ``` @@ -117,7 +117,7 @@ the following with a filter function during the prepare step: ``` quantize_( m, - intx_quantization_aware_training(weight_config=weight_config), + IntXQuantizationAwareTrainingConfig(weight_config=weight_config), filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding), ) ``` From 32274726376a9e2956931f16c6fa88c1ebe0fc57 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 14 Feb 2025 13:53:59 -0800 Subject: [PATCH 081/115] make quantize_.set_inductor_config None by default (#1716) make quantize_.set_inductor_config None by default for future deprecation Summary: We want to migrate this to individual workflows, see https://github.com/pytorch/ao/issues/1715 for migration plan. This PR is step 1 where we enable distinguishing whether the user specified this argument or not. After this PR, we can control the behavior per-workflow, such as setting this functionality to False for future training workflows. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- torchao/quantization/README.md | 3 +++ torchao/quantization/quant_api.py | 13 +++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 655a942718..a0e2ea2cc4 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -386,6 +386,9 @@ The benchmarks below were run on a single NVIDIA-A6000 GPU. You try can out these apis with the `quantize_` api as above alongside the constructor `codebook_weight_only` an example can be found in in `torchao/_models/llama/generate.py`. ### Automatic Inductor Configuration + +:warning: This functionality is being migrated from the top level `quantize_` API to individual workflows, see https://github.com/pytorch/ao/issues/1715 for more details. + The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues. ## (To be moved to prototype) A16W4 WeightOnly Quantization with GPTQ diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e347529929..0e7cda16f0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -488,7 +488,7 @@ def quantize_( model: torch.nn.Module, config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, - set_inductor_config: bool = True, + set_inductor_config: Optional[bool] = None, device: Optional[torch.types.Device] = None, ): """Convert the weight of linear modules in the model with `config`, model is modified inplace @@ -498,7 +498,7 @@ def quantize_( config (Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release. filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on the weight of the module - set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True) + set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to None) device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`. Defaults to None (do not change device). @@ -522,6 +522,15 @@ def quantize_( quantize_(m, int4_weight_only(group_size=32)) """ + if set_inductor_config != None: + warnings.warn( + """The `set_inductor_config` argument to `quantize_` will be removed in a future release. This functionality is being migrated to individual workflows. Please see https://github.com/pytorch/ao/issues/1715 for more details.""" + ) + else: # None + # for now, default to True to not change existing behavior when the + # argument is not specified + set_inductor_config = True + if set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() From c3bb80e40f930d85e48b24c24556e499b4f6b947 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 14 Feb 2025 15:45:45 -0800 Subject: [PATCH 082/115] mx formats: create MXLinearConfig (#1688) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 23 +++++++---- torchao/prototype/mx_formats/config.py | 15 ++++++- torchao/prototype/mx_formats/mx_linear.py | 22 ++++++++-- torchao/prototype/mx_formats/mx_ops.py | 6 ++- torchao/prototype/mx_formats/mx_tensor.py | 46 +++++++++++++++++---- 5 files changed, 91 insertions(+), 21 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index ad718beb9c..2a15961586 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -7,7 +7,6 @@ import pytest import torch -from torchao.prototype.mx_formats import config from torchao.prototype.mx_formats.constants import ( DTYPE_FP4, DTYPE_FP6_E2M3, @@ -139,8 +138,14 @@ def test_exponent_nan_out(elem_dtype): else: raise AssertionError("unsupported") block_size = 2 + use_fp4_custom_triton_dequant_kernel = False tensor_mx = MXTensor( - scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float + scale_e8m0_bits, + data_bits, + elem_dtype, + block_size, + torch.float, + use_fp4_custom_triton_dequant_kernel, ) tensor_hp = tensor_mx.to_dtype(torch.float) assert torch.all(torch.isnan(tensor_hp[0:1])) @@ -188,15 +193,16 @@ def test_transpose(elem_dtype, fp4_triton): M, K = 128, 256 block_size = 32 tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) - config.use_fp4_custom_triton_dequant_kernel = fp4_triton + tensor_mx = MXTensor.to_mx( + tensor_hp, + elem_dtype, + block_size, + use_fp4_custom_triton_dequant_kernel=fp4_triton, + ) tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t() - config.use_fp4_custom_triton_dequant_kernel = False tensor_mx_t = tensor_mx.t() - config.use_fp4_custom_triton_dequant_kernel = fp4_triton tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype) - config.use_fp4_custom_triton_dequant_kernel = False assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0) @@ -258,12 +264,14 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): to_dtype_c = torch.compile(to_dtype, fullgraph=True) + use_fp4_custom_triton_dequant_kernel = False x_mx_dq = to_dtype( x_mx._data, x_mx._scale_e8m0, x_mx._elem_dtype, x_mx._block_size, hp_dtype, # noqa: E501 + use_fp4_custom_triton_dequant_kernel, ) x_mx_c_dq = to_dtype_c( x_mx_c._data, @@ -271,5 +279,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx_c._elem_dtype, x_mx_c._block_size, hp_dtype, + use_fp4_custom_triton_dequant_kernel, ) torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 3e7e03d8f6..7b68b5b6a5 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -1,2 +1,13 @@ -# If True, uses a custom triton kernel for fp4 dequantize -use_fp4_custom_triton_dequant_kernel = False +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass +class MXLinearConfig: + # If True, uses a custom triton kernel for fp4 dequantize + use_fp4_custom_triton_dequant_kernel: bool = False diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index d7aa744334..72c2b6ab39 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -8,11 +8,12 @@ Defines the prototype UX for converting a model to use mx weights """ -from typing import Any +from typing import Any, Optional import torch import torch.nn.functional as F +from torchao.prototype.mx_formats.config import MXLinearConfig from torchao.prototype.mx_formats.mx_tensor import MXTensor @@ -110,6 +111,8 @@ def from_float( elem_dtype_weight_override=None, elem_dtype_grad_output_override=None, *, + # TODO(next PR): move elem_dtype* and block size into config + config: MXLinearConfig = None, block_size=32, ): mod.__class__ = MXLinear @@ -117,6 +120,10 @@ def from_float( mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype mod.block_size = block_size + # TODO(next PR): fix this + if config is None: + config = MXLinearConfig() + mod.config = config return mod def forward(self, x): @@ -151,7 +158,9 @@ class MXInferenceLinear(torch.nn.Linear): @classmethod @torch.no_grad() - def from_float(cls, mod, elem_dtype, block_size): + def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): + # TODO(next PR): move elem_dtype and block_size into config + with torch.device("meta"): super_kwargs = { "in_features": mod.in_features, @@ -166,6 +175,7 @@ def from_float(cls, mod, elem_dtype, block_size): ) new_mod.bias = mod.bias new_mod.elem_dtype = elem_dtype + new_mod.config = config return new_mod @torch.no_grad() @@ -207,6 +217,8 @@ def swap_linear_with_mx_linear( elem_dtype_weight_override=None, elem_dtype_grad_output_override=None, *, + # TODO(next PR): move elem_dtype* and block_size into config + config: Optional[MXLinearConfig] = None, block_size=32, filter_fn=None, ): @@ -225,6 +237,7 @@ def __fn(mod, fqn): elem_dtype, elem_dtype_weight_override, elem_dtype_grad_output_override, + config=config, block_size=block_size, ), combined_filter_fn, @@ -236,6 +249,7 @@ def swap_linear_with_mx_inference_linear( elem_dtype, block_size, filter_fn=None, + config: Optional[MXLinearConfig] = None, ): if filter_fn is None: combined_filter_fn = _is_linear @@ -247,6 +261,8 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXInferenceLinear.from_float(mod, elem_dtype, block_size), + lambda mod: MXInferenceLinear.from_float( + mod, elem_dtype, block_size, config=config + ), combined_filter_fn, ) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 57fb0d54b4..5fb3e8c6c0 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -54,6 +54,7 @@ def mx_desugar_op(aten_op, args, kwargs=None): old._elem_dtype, old._block_size, old._orig_dtype, + old._use_fp4_custom_triton_dequant_kernel, ) return new @@ -82,6 +83,7 @@ def mx_t(aten_op, args, kwargs=None): old._elem_dtype, old._block_size, old._orig_dtype, + old._use_fp4_custom_triton_dequant_kernel, ) return new @@ -120,6 +122,7 @@ def mx_view_op(aten_op, args, kwargs=None): args[0]._elem_dtype, args[0]._block_size, args[0]._orig_dtype, + args[0]._use_fp4_custom_triton_dequant_kernel, ) @@ -130,7 +133,6 @@ def autocast_to_copy(aten_op, args, kwargs=None): tensor. """ assert isinstance(args[0], MXTensor) - # print('before', args[0], args[0].dtype, args[0]._orig_dtype) assert ( len(kwargs) == 1 and "dtype" in kwargs ), "Only support dtype kwarg for autocast" @@ -144,6 +146,6 @@ def autocast_to_copy(aten_op, args, kwargs=None): args[0]._elem_dtype, args[0]._block_size, kwargs["dtype"], + args[0]._use_fp4_custom_triton_dequant_kernel, ) - # print('after', res, res.dtype, res._orig_dtype) return res diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 801f29ac3c..838ab2338c 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -21,7 +21,6 @@ import torch -import torchao.prototype.mx_formats.config as config from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, DTYPE_FP4, @@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0): return s_fp -def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype): +def to_dtype( + data_lp, + scale_e8m0, + elem_dtype, + block_size, + target_dtype, + use_fp4_custom_triton_dequant_kernel, +): orig_shape = data_lp.shape is_transposed = not data_lp.is_contiguous() # if the underlying data is transposed, convert to row major before @@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype): data_hp = f6_e3m2_unpacked_to_f32(data_lp) data_hp = data_hp.to(target_dtype) elif elem_dtype == DTYPE_FP4: - if config.use_fp4_custom_triton_dequant_kernel: + if use_fp4_custom_triton_dequant_kernel: data_hp_rescaled = triton_f4_to_scaled_bf16( data_lp, scale_e8m0, @@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function): """ @staticmethod - def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode): + def forward( + ctx, + data_hp, + elem_dtype, + block_size, + scaling_mode, + use_fp4_custom_triton_dequant_kernel, + ): scale_e8m0_biased, data_lp = to_mx( data_hp, elem_dtype, block_size, scaling_mode ) return MXTensor( - scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype + scale_e8m0_biased, + data_lp, + elem_dtype, + block_size, + data_hp.dtype, + use_fp4_custom_triton_dequant_kernel, ) @staticmethod def backward(ctx, g): - return g, None, None, None + return g, None, None, None, None @torch._dynamo.allow_in_graph @@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype): tensor_lp._elem_dtype, tensor_lp._block_size, target_dtype, + tensor_lp._use_fp4_custom_triton_dequant_kernel, ) @staticmethod @@ -360,6 +379,7 @@ def __new__( elem_dtype, block_size, orig_dtype, + use_fp4_custom_triton_dequant_kernel, ): new_size = data_bits.size() if elem_dtype == DTYPE_FP4: @@ -417,6 +437,9 @@ def __new__( self._elem_dtype = elem_dtype self._block_size = block_size self._orig_dtype = orig_dtype + self._use_fp4_custom_triton_dequant_kernel = ( + use_fp4_custom_triton_dequant_kernel + ) return self def __repr__(self): @@ -443,14 +466,22 @@ def to_mx( elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, + use_fp4_custom_triton_dequant_kernel: bool = False, ): - return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode) + return ToMXConstrFunc.apply( + data_hp, + elem_dtype, + block_size, + scaling_mode, + use_fp4_custom_triton_dequant_kernel, + ) def __tensor_flatten__(self): ctx = { "_elem_dtype": self._elem_dtype, "_block_size": self._block_size, "_orig_dtype": self._orig_dtype, + "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel, } return ["_scale_e8m0", "_data"], ctx @@ -467,6 +498,7 @@ def __tensor_unflatten__( metadata["_elem_dtype"], metadata["_block_size"], metadata["_orig_dtype"], + metadata["_use_fp4_custom_triton_dequant_kernel"], ) # Do not force the MXTensor type on the returned tensor From 40d01cd08168eb6428dc17bb40a474ed4bbde7d2 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 14 Feb 2025 15:46:52 -0800 Subject: [PATCH 083/115] MX: move block_size and elem_dtype into MXLinearConfig (#1689) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 47 +++++++--------- torchao/prototype/mx_formats/README.md | 11 ++-- torchao/prototype/mx_formats/config.py | 31 +++++++++++ torchao/prototype/mx_formats/mx_linear.py | 60 +++++++-------------- 4 files changed, 74 insertions(+), 75 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 17a76a750d..c2eb66960f 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -11,6 +11,7 @@ import torch import torch.nn as nn +from torchao.prototype.mx_formats.config import MXLinearConfig from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES from torchao.prototype.mx_formats.mx_linear import ( MXInferenceLinear, @@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape): nn.Linear(8, 6, bias=bias, device="cuda"), ) m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size) + config = MXLinearConfig( + block_size=2, + elem_dtype=elem_dtype[0], + elem_dtype_weight_override=elem_dtype[1], + elem_dtype_grad_output_override=elem_dtype[2], + ) + swap_linear_with_mx_linear(m_mx, config=config) x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() x = copy.deepcopy(x_ref) @@ -97,8 +103,8 @@ def test_activation_checkpointing(): nn.Linear(4, 6, bias=True, device="cuda"), nn.Linear(6, 6, bias=True, device="cuda"), ) - block_size = 2 - swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_linear(m, config=config) x = torch.randn(*input_shape, device="cuda").requires_grad_() g = torch.randn(*grad_shape, device="cuda") @@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast): m_mx = nn.Sequential( nn.Linear(K, N, bias=bias, device="cuda"), ) - block_size = 2 - swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_linear(m_mx, config=config) m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") @@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape): m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16)) m = m.cuda() m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_inference_linear(m_mx, config=config) x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16) y_ref = m(x) @@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype): m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) m = m.cuda() m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_inference_linear(m_mx, config=config) m_mx = torch.compile(m_mx, fullgraph="true") x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16) @@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype): assert sqnr >= 13.5 -def test_mx_linear_input_weight_gradient_dtypes(): - m = nn.Sequential(nn.Linear(32, 32)) - swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32) - assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0] - assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1] - assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2] - - m = nn.Sequential(nn.Linear(32, 32)) - swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32) - assert m[0].in_elem_dtype == torch.float8_e4m3fn - assert m[0].w_elem_dtype == torch.float8_e4m3fn - assert m[0].grad_elem_dtype == torch.float8_e4m3fn - - def test_filter_fn(): m1 = nn.Sequential( nn.Linear(32, 32), @@ -245,12 +237,11 @@ def test_filter_fn(): m2 = copy.deepcopy(m1) filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731 - swap_linear_with_mx_linear( - m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn - ) + config = MXLinearConfig(block_size=32) + swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn) assert type(m1[0]) == MXLinear assert type(m1[1]) == torch.nn.Linear - swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501 + swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn) # noqa: E501 assert type(m2[0]) == MXInferenceLinear assert type(m2[1]) == torch.nn.Linear diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 32f45e3755..09e7563ebb 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -41,10 +41,11 @@ This is a module to do MX training, the MX matmul is currently emulated. ```python from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear +from torchao.prototype.mx_formats.config import MXLinearConfig m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -elem_dtype = torch.float8_e4m3fn -swap_linear_with_mx_linear(m, elem_dtype, block_size=32) +config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32) +swap_linear_with_mx_linear(m, config=config) # training loop (not shown) ``` @@ -55,11 +56,11 @@ This is a module to do MX inference, weights are in MX and matmul is in high pre ```python from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear +from torchao.prototype.mx_formats.config import MXLinearConfig m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -elem_dtype = torch.float8_e4m3fn -block_size = 32 -swap_linear_with_mx_inference_linear(m, elem_dtype, block_size) +config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32) +swap_linear_with_mx_inference_linear(m, config=config) # do inference (not shown) ``` diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 7b68b5b6a5..7cdf2d4e58 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -5,9 +5,40 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES @dataclass class MXLinearConfig: + # block size for scaling, default is 32 to match + # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, + # section 5.2 + block_size: int = 32 + + # element dtype, used for activations, weights and gradients + elem_dtype: Any = torch.float8_e4m3fn + + # overrides for element dtype for weights and gradients + # TODO(future PR): refactor to make this cleaner + elem_dtype_weight_override: Optional[Any] = None + elem_dtype_grad_output_override: Optional[Any] = None + # If True, uses a custom triton kernel for fp4 dequantize use_fp4_custom_triton_dequant_kernel: bool = False + + def __post_init__(self): + assert ( + self.elem_dtype in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" + if self.elem_dtype_weight_override is not None: + assert ( + self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" + if self.elem_dtype_grad_output_override is not None: + assert ( + self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 72c2b6ab39..a38a8c5499 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -107,22 +107,11 @@ class MXLinear(torch.nn.Linear): def from_float( cls, mod, - elem_dtype, - elem_dtype_weight_override=None, - elem_dtype_grad_output_override=None, - *, - # TODO(next PR): move elem_dtype* and block size into config - config: MXLinearConfig = None, - block_size=32, + config: Optional[MXLinearConfig] = MXLinearConfig(), ): + # TODO(before land): remove this + assert isinstance(config, MXLinearConfig) mod.__class__ = MXLinear - mod.in_elem_dtype = elem_dtype - mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype - mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype - mod.block_size = block_size - # TODO(next PR): fix this - if config is None: - config = MXLinearConfig() mod.config = config return mod @@ -135,13 +124,14 @@ def forward(self, x): else: w = self.weight + config = self.config y = mx_mm.apply( x, w, - self.in_elem_dtype, - self.w_elem_dtype, - self.grad_elem_dtype, - self.block_size, + config.elem_dtype, + config.elem_dtype_weight_override or config.elem_dtype, + config.elem_dtype_grad_output_override or config.elem_dtype, + config.block_size, ) if self.bias is not None: y = y + self.bias @@ -158,9 +148,11 @@ class MXInferenceLinear(torch.nn.Linear): @classmethod @torch.no_grad() - def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): - # TODO(next PR): move elem_dtype and block_size into config - + def from_float( + cls, + mod, + config: Optional[MXLinearConfig] = MXLinearConfig(), + ): with torch.device("meta"): super_kwargs = { "in_features": mod.in_features, @@ -171,10 +163,9 @@ def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): # TODO(future PR): set to new_mod.weight directly, will need to work # through some errors new_mod.weight_mx = MXTensor.to_mx( - mod.weight, elem_dtype, block_size=block_size + mod.weight, config.elem_dtype, block_size=config.block_size ) new_mod.bias = mod.bias - new_mod.elem_dtype = elem_dtype new_mod.config = config return new_mod @@ -213,13 +204,8 @@ def _is_linear(mod, fqn): def swap_linear_with_mx_linear( model, - elem_dtype, - elem_dtype_weight_override=None, - elem_dtype_grad_output_override=None, *, - # TODO(next PR): move elem_dtype* and block_size into config config: Optional[MXLinearConfig] = None, - block_size=32, filter_fn=None, ): if filter_fn is None: @@ -232,24 +218,16 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXLinear.from_float( - mod, - elem_dtype, - elem_dtype_weight_override, - elem_dtype_grad_output_override, - config=config, - block_size=block_size, - ), + lambda mod: MXLinear.from_float(mod, config=config), combined_filter_fn, ) def swap_linear_with_mx_inference_linear( model, - elem_dtype, - block_size, - filter_fn=None, + *, config: Optional[MXLinearConfig] = None, + filter_fn=None, ): if filter_fn is None: combined_filter_fn = _is_linear @@ -261,8 +239,6 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXInferenceLinear.from_float( - mod, elem_dtype, block_size, config=config - ), + lambda mod: MXInferenceLinear.from_float(mod, config=config), combined_filter_fn, ) From 8fc49fe0cb725a159f1bb0b1262d531b4655efdb Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Fri, 14 Feb 2025 15:48:55 -0800 Subject: [PATCH 084/115] MX: hook up mxfp8 and mxfp4 CUTLASS kernels to MXLinear (#1713) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 51 +++++++++++++++++++-- test/prototype/mx_formats/test_mx_tensor.py | 2 + torchao/prototype/mx_formats/README.md | 15 +++++- torchao/prototype/mx_formats/config.py | 38 ++++++++++++++- torchao/prototype/mx_formats/mx_linear.py | 43 +++++++++++++---- torchao/prototype/mx_formats/mx_ops.py | 42 ++++++++++++++--- torchao/prototype/mx_formats/mx_tensor.py | 11 ++++- 7 files changed, 180 insertions(+), 22 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index c2eb66960f..87451bf621 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -11,8 +11,8 @@ import torch import torch.nn as nn -from torchao.prototype.mx_formats.config import MXLinearConfig -from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES +from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig +from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES from torchao.prototype.mx_formats.mx_linear import ( MXInferenceLinear, MXLinear, @@ -50,7 +50,9 @@ def run_around_tests(): @pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)]) def test_linear_eager(elem_dtype, bias, input_shape): """ - Smoke test for training linear module with mx weight + Smoke test for training linear module with mx weight, compares the following: + * baseline: float32 + * experiment: emulated MX """ # elem_dtype is a tuple of (input, weight, gradient) dtypes. grad_shape = list(input_shape) @@ -92,6 +94,49 @@ def test_linear_eager(elem_dtype, bias, input_shape): assert x_g_sqnr >= 8.0 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8" +) +@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, DTYPE_FP4]) +@pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)]) +def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn): + M, K, N = 128, 128, 128 + M, K, N = mkn + + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_() + x_copy = copy.deepcopy(x) + g = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + m_emulated = nn.Sequential( + nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16), + ) + m_real = copy.deepcopy(m_emulated) + + config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype) + config_real = MXLinearConfig( + block_size=32, + elem_dtype=elem_dtype, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + ) + + swap_linear_with_mx_linear(m_emulated, config=config_emulated) + swap_linear_with_mx_linear(m_real, config=config_real) + + y_emulated = m_emulated(x) + y_emulated.backward(g) + + y_real = m_real(x_copy) + y_real.backward(g) + + with torch.no_grad(): + y_sqnr = compute_error(y_real, y_emulated) + w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad) + g_sqnr = compute_error(x_copy.grad, x.grad) + assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!" + assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!" + assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!" + + # TODO(future): enable compile support @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_activation_checkpointing(): diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 2a15961586..f5014b7e31 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -7,6 +7,7 @@ import pytest import torch +from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import ( DTYPE_FP4, DTYPE_FP6_E2M3, @@ -146,6 +147,7 @@ def test_exponent_nan_out(elem_dtype): block_size, torch.float, use_fp4_custom_triton_dequant_kernel, + MXGemmKernelChoice.EMULATED, ) tensor_hp = tensor_mx.to_dtype(torch.float) assert torch.all(torch.isnan(tensor_hp[0:1])) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 09e7563ebb..1f1db18b7d 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -41,10 +41,21 @@ This is a module to do MX training, the MX matmul is currently emulated. ```python from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear -from torchao.prototype.mx_formats.config import MXLinearConfig +from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice +from torchao.utils import is_sm_at_least_100 + +# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by +# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support. +gemm_kernel_choice = MXGemmKernelChoice.EMULATED +if is_sm_at_least_100(): + gemm_kernel_choice = MXGemmKernelChoice.CUTLASS m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32) +config = MXLinearConfig( + elem_dtype=torch.float8_e4m3fn, + block_size=32, + gemm_kernel_choice=gemm_kernel_choice, +) swap_linear_with_mx_linear(m, config=config) # training loop (not shown) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 7cdf2d4e58..d511d2614d 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -5,11 +5,26 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass +from enum import Enum from typing import Any, Optional import torch -from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES +from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + SUPPORTED_ELEM_DTYPES, +) + + +class MXGemmKernelChoice(Enum): + # always available - MX operands are dequantized and a high precision + # gemm is run + EMULATED = "emulated" + + # available only when CUDA capability is greater than or equal to 10.0 + CUTLASS = "cutlass" + + # TODO(future PR): add cuBLAS here once we land pytorch/pytorch support @dataclass @@ -27,10 +42,15 @@ class MXLinearConfig: elem_dtype_weight_override: Optional[Any] = None elem_dtype_grad_output_override: Optional[Any] = None + # defines the gemm kernel choice, if the chosen kernel is not supported + # on the given hardware an exception will be thrown + gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED + # If True, uses a custom triton kernel for fp4 dequantize use_fp4_custom_triton_dequant_kernel: bool = False def __post_init__(self): + # validate elem_dtype and its overrides assert ( self.elem_dtype in SUPPORTED_ELEM_DTYPES ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" @@ -42,3 +62,19 @@ def __post_init__(self): assert ( self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" + + # validate that block size and elem_dtype matches kernel choice + if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS: + assert ( + self.block_size == 32 + ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}" + valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4] + assert ( + self.elem_dtype in valid_dtypes + ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}" + assert ( + self.elem_dtype_weight_override is None + ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels" + assert ( + self.elem_dtype_grad_output_override is None + ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels" diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index a38a8c5499..e15f2ad727 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -13,7 +13,7 @@ import torch import torch.nn.functional as F -from torchao.prototype.mx_formats.config import MXLinearConfig +from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig from torchao.prototype.mx_formats.mx_tensor import MXTensor @@ -36,19 +36,25 @@ def forward( w_elem_dtype: Any, grad_elem_dtype: Any, block_size: int, + gemm_kernel_choice: MXGemmKernelChoice, ): ctx.save_for_backward(input_hp, weight_hp) ctx.in_elem_dtype = in_elem_dtype ctx.w_elem_dtype = w_elem_dtype ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size + ctx.gemm_kernel_choice = gemm_kernel_choice # input @ weight_t = output input_orig_shape = input_hp.shape input_hp_r = input_hp.reshape(-1, input_orig_shape[-1]) - input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size) - weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size) + input_mx_r_dim0 = MXTensor.to_mx( + input_hp_r, in_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice + ) + weight_mx_dim0 = MXTensor.to_mx( + weight_hp, w_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice + ) output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t()) output = output.reshape(*input_orig_shape[:-1], output.shape[-1]) @@ -62,6 +68,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): w_elem_dtype = ctx.w_elem_dtype grad_elem_dtype = ctx.grad_elem_dtype block_size = ctx.block_size + gemm_kernel_choice = ctx.gemm_kernel_choice grad_output_orig_shape = grad_output_hp.shape grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1]) @@ -71,9 +78,17 @@ def backward(ctx, grad_output_hp: torch.Tensor): # grad_output @ weight = grad_input grad_output_mx_dim0 = MXTensor.to_mx( - grad_output_hp_r, grad_elem_dtype, block_size + grad_output_hp_r, + grad_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, + ) + weight_mx_dim1 = MXTensor.to_mx( + weight_hp_t_c, + w_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, ) - weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) grad_input = grad_input.reshape( *grad_output_orig_shape[:-1], grad_input.shape[-1] @@ -81,15 +96,21 @@ def backward(ctx, grad_output_hp: torch.Tensor): # input_t @ grad_output = grad_weight grad_output_mx_dim1 = MXTensor.to_mx( - grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size + grad_output_hp_r.t().contiguous(), + grad_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, ) input_t_mx_dim0_tmp = MXTensor.to_mx( - input_hp_r.t().contiguous(), in_elem_dtype, block_size + input_hp_r.t().contiguous(), + in_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) - return grad_input, grad_weight, None, None, None, None + return grad_input, grad_weight, None, None, None, None, None class MXLinear(torch.nn.Linear): @@ -132,6 +153,7 @@ def forward(self, x): config.elem_dtype_weight_override or config.elem_dtype, config.elem_dtype_grad_output_override or config.elem_dtype, config.block_size, + config.gemm_kernel_choice, ) if self.bias is not None: y = y + self.bias @@ -163,7 +185,10 @@ def from_float( # TODO(future PR): set to new_mod.weight directly, will need to work # through some errors new_mod.weight_mx = MXTensor.to_mx( - mod.weight, config.elem_dtype, block_size=config.block_size + mod.weight, + config.elem_dtype, + block_size=config.block_size, + gemm_kernel_choice=config.gemm_kernel_choice, ) new_mod.bias = mod.bias new_mod.config = config diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 5fb3e8c6c0..16e61e0653 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -22,11 +22,15 @@ import torch from torch.utils._pytree import tree_map +# from torchao.ops import mx_fp4_bf16, mx_fp8_bf16 +import torchao.ops +from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import DTYPE_FP4 from torchao.prototype.mx_formats.mx_tensor import ( # noqa: E501 MXTensor, tensor_size_hp_to_fp4x2, ) +from torchao.prototype.mx_formats.utils import to_blocked aten = torch.ops.aten @@ -55,6 +59,7 @@ def mx_desugar_op(aten_op, args, kwargs=None): old._block_size, old._orig_dtype, old._use_fp4_custom_triton_dequant_kernel, + old._gemm_kernel_choice, ) return new @@ -64,12 +69,34 @@ def mx_mm(aten_op, args, kwargs=None): a = args[0] b = args[1] assert isinstance(a, MXTensor) and isinstance(b, MXTensor) - a_hp = a.to_dtype(a._orig_dtype) - b_hp = b.to_dtype(b._orig_dtype) - # assert memory layout we expect to be required in hardware - assert a_hp.is_contiguous() - assert b_hp.t().is_contiguous() - res = aten_op(a_hp, b_hp) + assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported" + if a._gemm_kernel_choice == MXGemmKernelChoice.CUTLASS: + # real MX gemm backed by torchao's CUTLASS kernels + M, K, N = a.shape[0], a.shape[1], b.shape[1] + assert b._data.t().is_contiguous() + a_scale = a._scale_e8m0.view(M, K // 32) + b_scale = b._scale_e8m0.view(N, K // 32) + a_scale_block = to_blocked(a_scale) + b_scale_block = to_blocked(b_scale) + if a._elem_dtype == torch.float8_e4m3fn: + assert b._elem_dtype == torch.float8_e4m3fn + res = torchao.ops.mx_fp8_bf16( + a._data, b._data, a_scale_block, b_scale_block + ) + else: + assert a._elem_dtype == DTYPE_FP4 + assert b._elem_dtype == DTYPE_FP4 + res = torchao.ops.mx_fp4_bf16( + a._data, b._data, a_scale_block, b_scale_block + ) + else: + # emulated MX gemm + a_hp = a.to_dtype(a._orig_dtype) + b_hp = b.to_dtype(b._orig_dtype) + # assert memory layout we expect to be required in hardware + assert a_hp.is_contiguous() + assert b_hp.t().is_contiguous() + res = aten_op(a_hp, b_hp) return res @@ -84,6 +111,7 @@ def mx_t(aten_op, args, kwargs=None): old._block_size, old._orig_dtype, old._use_fp4_custom_triton_dequant_kernel, + old._gemm_kernel_choice, ) return new @@ -123,6 +151,7 @@ def mx_view_op(aten_op, args, kwargs=None): args[0]._block_size, args[0]._orig_dtype, args[0]._use_fp4_custom_triton_dequant_kernel, + args[0]._gemm_kernel_choice, ) @@ -147,5 +176,6 @@ def autocast_to_copy(aten_op, args, kwargs=None): args[0]._block_size, kwargs["dtype"], args[0]._use_fp4_custom_triton_dequant_kernel, + args[0]._gemm_kernel_choice, ) return res diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 838ab2338c..6c0a718c78 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -21,6 +21,7 @@ import torch +from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, DTYPE_FP4, @@ -331,6 +332,7 @@ def forward( block_size, scaling_mode, use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, ): scale_e8m0_biased, data_lp = to_mx( data_hp, elem_dtype, block_size, scaling_mode @@ -342,11 +344,12 @@ def forward( block_size, data_hp.dtype, use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, ) @staticmethod def backward(ctx, g): - return g, None, None, None, None + return g, None, None, None, None, None @torch._dynamo.allow_in_graph @@ -380,6 +383,7 @@ def __new__( block_size, orig_dtype, use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, ): new_size = data_bits.size() if elem_dtype == DTYPE_FP4: @@ -440,6 +444,7 @@ def __new__( self._use_fp4_custom_triton_dequant_kernel = ( use_fp4_custom_triton_dequant_kernel ) + self._gemm_kernel_choice = gemm_kernel_choice return self def __repr__(self): @@ -467,6 +472,7 @@ def to_mx( block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, use_fp4_custom_triton_dequant_kernel: bool = False, + gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, ): return ToMXConstrFunc.apply( data_hp, @@ -474,6 +480,7 @@ def to_mx( block_size, scaling_mode, use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, ) def __tensor_flatten__(self): @@ -482,6 +489,7 @@ def __tensor_flatten__(self): "_block_size": self._block_size, "_orig_dtype": self._orig_dtype, "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel, + "_gemm_kernel_choice": self._gemm_kernel_choice, } return ["_scale_e8m0", "_data"], ctx @@ -499,6 +507,7 @@ def __tensor_unflatten__( metadata["_block_size"], metadata["_orig_dtype"], metadata["_use_fp4_custom_triton_dequant_kernel"], + metadata["_gemm_kernel_choice"], ) # Do not force the MXTensor type on the returned tensor From 22d7d51e73954d5d70189d18407b56cd10d852f4 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Mon, 17 Feb 2025 18:16:44 -0800 Subject: [PATCH 085/115] Reformat (#1723) * reformat * up --- .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 75 +- .../kernels/cpu/aarch64/tests/test_linear.cpp | 638 +++------ .../linear_8bit_act_xbit_weight.cpp | 245 ++-- .../linear_8bit_act_xbit_weight.h | 129 +- .../op_linear_8bit_act_xbit_weight-impl.h | 341 ++--- .../test_linear_8bit_act_xbit_weight.cpp | 1275 ++++++----------- 6 files changed, 962 insertions(+), 1741 deletions(-) diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 167ccc47df..9cde684995 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -23,16 +23,14 @@ namespace torchao::kernels::cpu::aarch64::kleidi { // Helper functions // TODO: find a better place for these? -size_t roundup(size_t a, size_t b) { - return ((a + b - 1) / b) * b; -} +size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } uint16_t get_bf16_from_float(float f) { uint16_t bf16; #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ memcpy(&bf16, &f, sizeof(uint16_t)); #else - const void* fp = reinterpret_cast( + const void *fp = reinterpret_cast( reinterpret_cast(&f) + sizeof(float) - sizeof(uint16_t)); memcpy(&bf16, fp, sizeof(uint16_t)); #endif // __BYTE_ORDER__ @@ -45,52 +43,31 @@ using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; size_t activation_data_size(const Ukernel ukernel, int m, int k) { auto lhs_packing = get_lhs_packing(); - return lhs_packing.get_lhs_packed_size( - m, k, ukernel.get_mr(), ukernel.get_kr(), ukernel.get_sr()); + return lhs_packing.get_lhs_packed_size(m, k, ukernel.get_mr(), + ukernel.get_kr(), ukernel.get_sr()); } -void prepare_activation_data( - const Ukernel ukernel, - void* activation_data, - int m, - int k, - const float* activations) { +void prepare_activation_data(const Ukernel ukernel, void *activation_data, + int m, int k, const float *activations) { auto lhs_pack = get_lhs_packing(); - lhs_pack.run_lhs_pack( - m, - k, - ukernel.get_mr(), - ukernel.get_kr(), - ukernel.get_sr(), - /*m_index_start=*/0, - activations, - /*lhs_stride=*/k * sizeof(float), - activation_data); + lhs_pack.run_lhs_pack(m, k, ukernel.get_mr(), ukernel.get_kr(), + ukernel.get_sr(), + /*m_index_start=*/0, activations, + /*lhs_stride=*/k * sizeof(float), activation_data); } size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) { auto rhs_pack = get_rhs_packing(); - return rhs_pack.get_rhs_packed_size( - n, - k, - ukernel.get_nr(), - ukernel.get_kr(), - ukernel.get_sr(), - group_size, - kai_datatype::kai_dt_bf16); + return rhs_pack.get_rhs_packed_size(n, k, ukernel.get_nr(), ukernel.get_kr(), + ukernel.get_sr(), group_size, + kai_datatype::kai_dt_bf16); } -void prepare_weight_data( - const Ukernel ukernel, - void* weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { +void prepare_weight_data(const Ukernel ukernel, void *weight_data, int n, int k, + int group_size, const int8_t *weight_qvals, + const float *weight_scales, const int8_t *weight_zeros, + const float *bias) { // TODO(T204312268) - remove this constraint and pad when possible assert(n % 2 == 0); @@ -123,25 +100,19 @@ void prepare_weight_data( } // Parameters for packing - rhs_packing::qparams_t qparams{ - .lhs_zero_point = 1, - .rhs_zero_point = wzp, - .scale_dt = kai_datatype::kai_dt_bf16}; + rhs_packing::qparams_t qparams{.lhs_zero_point = 1, + .rhs_zero_point = wzp, + .scale_dt = kai_datatype::kai_dt_bf16}; auto rhs_pack = get_rhs_packing(); rhs_pack.run_rhs_pack( - /*groups=*/1, - n, - k, - ukernel.get_nr(), - ukernel.get_kr(), - ukernel.get_sr(), + /*groups=*/1, n, k, ukernel.get_nr(), ukernel.get_kr(), ukernel.get_sr(), group_size, - /*rhs=*/reinterpret_cast(packed_weight_qvals.data()), + /*rhs=*/reinterpret_cast(packed_weight_qvals.data()), /*rhs_stride=*/roundup(k, 2) / 2, /*bias=*/bias, - /*scale=*/reinterpret_cast(weight_scales_bf16.data()), + /*scale=*/reinterpret_cast(weight_scales_bf16.data()), /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size), /*rhs_packed=*/weight_data, /*extra_bytes=*/0, diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index f68106c7e8..070e7bebfb 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -27,55 +27,33 @@ float kTol = 0.0001; template void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot( - int m, - int k, - int n, - int group_size) { - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp); + int m, int k, int n, int group_size) { + auto test_case = + torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: + generate(m, k, n, group_size, weight_nbit, has_weight_zeros, has_bias, + has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot; std::vector activation_data( activation_data_size(m, k, group_size)); - prepare_activation_data( - (void*)activation_data.data(), - m, - k, - group_size, - test_case.activations.data()); + prepare_activation_data((void *)activation_data.data(), m, + k, group_size, + test_case.activations.data()); std::vector weight_data( - weight_data_size( - n, k, group_size)); + weight_data_size(n, k, + group_size)); prepare_weight_data( - (void*)weight_data.data(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - test_case.bias.data()); + (void *)weight_data.data(), n, k, group_size, + test_case.weight_qvals.data(), test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), test_case.bias.data()); std::vector output(m * n); kernel( output.data(), - /*output_m_stride=*/n, - m, - n, - k, - group_size, - weight_data.data(), + /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), activation_data.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max); @@ -89,9 +67,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, Standard) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); } @@ -100,9 +76,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasWeightZeros) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, true /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); } @@ -111,9 +85,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasBias) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); } @@ -122,64 +94,40 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot, HasClamp) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, true /*has_clamp*/>( /*m=*/7, /*k=*/128, /*n=*/13, /*group_size=*/32); } template void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot( - int m, - int k, - int n, - int group_size) { - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp); + int m, int k, int n, int group_size) { + auto test_case = + torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: + generate(m, k, n, group_size, weight_nbit, has_weight_zeros, has_bias, + has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot; std::vector activation_data( activation_data_size(m, k, group_size)); - prepare_activation_data( - (void*)activation_data.data(), - m, - k, - group_size, - test_case.activations.data()); + prepare_activation_data((void *)activation_data.data(), m, + k, group_size, + test_case.activations.data()); std::vector weight_data( - weight_data_size( - n, k, group_size)); + weight_data_size(n, k, + group_size)); prepare_weight_data( - (void*)weight_data.data(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - test_case.bias.data()); + (void *)weight_data.data(), n, k, group_size, + test_case.weight_qvals.data(), test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), test_case.bias.data()); std::vector output(m * n); kernel( output.data(), - /*output_m_stride=*/n, - m, - n, - k, - group_size, - weight_data.data(), + /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), activation_data.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max); @@ -193,9 +141,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, Standard) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } @@ -204,9 +150,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasWeightZeros) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, true /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } @@ -215,9 +159,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasBias) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } @@ -226,9 +168,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot, HasClamp) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, true /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } @@ -238,9 +178,7 @@ TEST( NLessThan4) { for (int n = 1; n < 4; n++) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, true /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); } @@ -248,55 +186,33 @@ TEST( template void test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot( - int m, - int k, - int n, - int group_size) { - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp); + int m, int k, int n, int group_size) { + auto test_case = + torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: + generate(m, k, n, group_size, weight_nbit, has_weight_zeros, has_bias, + has_clamp); using namespace torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; std::vector activation_data( activation_data_size(m, k, group_size)); - prepare_activation_data( - (void*)activation_data.data(), - m, - k, - group_size, - test_case.activations.data()); + prepare_activation_data((void *)activation_data.data(), m, + k, group_size, + test_case.activations.data()); std::vector weight_data( - weight_data_size( - n, k, group_size)); + weight_data_size(n, k, + group_size)); prepare_weight_data( - (void*)weight_data.data(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - test_case.bias.data()); + (void *)weight_data.data(), n, k, group_size, + test_case.weight_qvals.data(), test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), test_case.bias.data()); std::vector output(m * n); kernel( output.data(), - /*output_m_stride=*/n, - m, - n, - k, - group_size, - weight_data.data(), + /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), activation_data.data(), /*clamp_min=*/test_case.clamp_min, /*clamp_max=*/test_case.clamp_max); @@ -310,9 +226,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, Standard) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } @@ -321,9 +235,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasWeightZeros) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, true /*has_weight_zeros*/, false /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } @@ -332,9 +244,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasBias) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, false /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } @@ -343,9 +253,7 @@ TEST( test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot, HasClamp) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, true /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/13, /*group_size=*/16); } @@ -355,9 +263,7 @@ TEST( NLessThan8) { for (int n = 1; n < 8; n++) { test_channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, true /*has_clamp*/>( /*m=*/7, /*k=*/64, /*n=*/n, /*group_size=*/16); } @@ -366,458 +272,322 @@ TEST( #ifdef TORCHAO_ENABLE_KLEIDI template void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( - int m, - int k, - int n, - int group_size) { - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - /*weight_nbit=*/4, - /*has_weight_zeros*/ false, - has_bias, - has_clamp, - /*weight_scale_bf16_round_trip=*/true); + int m, int k, int n, int group_size) { + auto test_case = + torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: + generate(m, k, n, group_size, + /*weight_nbit=*/4, + /*has_weight_zeros*/ false, has_bias, has_clamp, + /*weight_scale_bf16_round_trip=*/true); using namespace torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32; std::vector activation_data(activation_data_size(m, k, group_size)); - prepare_activation_data( - (void*)activation_data.data(), - m, - k, - group_size, - test_case.activations.data()); + prepare_activation_data((void *)activation_data.data(), m, k, group_size, + test_case.activations.data()); std::vector weight_data(weight_data_size(n, k, group_size)); - prepare_weight_data( - (void*)weight_data.data(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); + prepare_weight_data((void *)weight_data.data(), n, k, group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), + /*bias=*/test_case.bias.data()); std::vector output(m * n); - kernel( - output.data(), - /*output_m_stride=*/n, - m, - n, - k, - group_size, - weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); + kernel(output.data(), + /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), + activation_data.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); for (int i = 0; i < m * n; i++) { EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); } } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - k_eq_gs_32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + k_eq_gs_32) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - large_k_n_gs32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + large_k_n_gs32) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - even_n_gs32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + even_n_gs32) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - clamp_k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + clamp_k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, - true /*has_clamp*/>( + false /*has_bias*/, true /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - m_clamp_k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + m_clamp_k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, - true /*has_clamp*/>( + false /*has_bias*/, true /*has_clamp*/>( /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); } template void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( - int m, - int k, - int n, - int group_size) { - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - /*weight_nbit=*/4, - /*has_weight_zeros=*/false, - has_bias, - has_clamp, - /*round_weight_scales_to_bf16=*/true); + int m, int k, int n, int group_size) { + auto test_case = + torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: + generate(m, k, n, group_size, + /*weight_nbit=*/4, + /*has_weight_zeros=*/false, has_bias, has_clamp, + /*round_weight_scales_to_bf16=*/true); using namespace torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; std::vector activation_data(activation_data_size(m, k, group_size)); - prepare_activation_data( - (void*)activation_data.data(), - m, - k, - group_size, - test_case.activations.data()); + prepare_activation_data((void *)activation_data.data(), m, k, group_size, + test_case.activations.data()); std::vector weight_data(weight_data_size(n, k, group_size)); - prepare_weight_data( - (void*)weight_data.data(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); + prepare_weight_data((void *)weight_data.data(), n, k, group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), + /*bias=*/test_case.bias.data()); std::vector output(m * n); - kernel( - output.data(), - /*output_m_stride=*/n, - m, - n, - k, - group_size, - weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); + kernel(output.data(), + /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), + activation_data.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); for (int i = 0; i < m * n; i++) { EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); } } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - k_eq_gs_32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + k_eq_gs_32) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - large_k_n_gs32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + large_k_n_gs32) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - even_n_gs32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + even_n_gs32) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - clamp_k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + clamp_k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, - true /*has_clamp*/>( + false /*has_bias*/, true /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - m_clamp_k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + m_clamp_k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, - true /*has_clamp*/>( + false /*has_bias*/, true /*has_clamp*/>( /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); } #ifdef TORCHAO_ENABLE_ARM_I8MM template void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( - int m, - int k, - int n, - int group_size) { - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - /*weight_nbit=*/4, - /*has_weight_zeros=*/false, - has_bias, - has_clamp, - /*round_weight_scales_to_bf16=*/true); + int m, int k, int n, int group_size) { + auto test_case = + torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: + generate(m, k, n, group_size, + /*weight_nbit=*/4, + /*has_weight_zeros=*/false, has_bias, has_clamp, + /*round_weight_scales_to_bf16=*/true); using namespace torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32; std::vector activation_data(activation_data_size(m, k, group_size)); - prepare_activation_data( - (void*)activation_data.data(), - m, - k, - group_size, - test_case.activations.data()); + prepare_activation_data((void *)activation_data.data(), m, k, group_size, + test_case.activations.data()); std::vector weight_data(weight_data_size(n, k, group_size)); - prepare_weight_data( - (void*)weight_data.data(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); + prepare_weight_data((void *)weight_data.data(), n, k, group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), + /*bias=*/test_case.bias.data()); std::vector output(m * n); - kernel( - output.data(), - /*output_m_stride=*/n, - m, - n, - k, - group_size, - weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); + kernel(output.data(), + /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), + activation_data.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); for (int i = 0; i < m * n; i++) { EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); } } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - k_eq_gs_32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + k_eq_gs_32) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - large_k_n_gs32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + large_k_n_gs32) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - even_n_gs32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + even_n_gs32) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - clamp_k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + clamp_k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, - true /*has_clamp*/>( + false /*has_bias*/, true /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - m_clamp_k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, + m_clamp_k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, - true /*has_clamp*/>( + false /*has_bias*/, true /*has_clamp*/>( /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); } template void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( - int m, - int k, - int n, - int group_size) { - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - /*weight_nbit=*/4, - /*has_weight_zeros=*/false, - has_bias, - has_clamp, - /*round_weight_scales_to_bf16=*/true); + int m, int k, int n, int group_size) { + auto test_case = + torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: + generate(m, k, n, group_size, + /*weight_nbit=*/4, + /*has_weight_zeros=*/false, has_bias, has_clamp, + /*round_weight_scales_to_bf16=*/true); using namespace torchao::kernels::cpu::aarch64::kleidi:: kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32; std::vector activation_data(activation_data_size(m, k, group_size)); - prepare_activation_data( - (void*)activation_data.data(), - m, - k, - group_size, - test_case.activations.data()); + prepare_activation_data((void *)activation_data.data(), m, k, group_size, + test_case.activations.data()); std::vector weight_data(weight_data_size(n, k, group_size)); - prepare_weight_data( - (void*)weight_data.data(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); + prepare_weight_data((void *)weight_data.data(), n, k, group_size, + test_case.weight_qvals.data(), + test_case.weight_scales.data(), + /*weight_zeros=*/test_case.weight_zeros.data(), + /*bias=*/test_case.bias.data()); std::vector output(m * n); - kernel( - output.data(), - /*output_m_stride=*/n, - m, - n, - k, - group_size, - weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); + kernel(output.data(), + /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), + activation_data.data(), + /*clamp_min=*/test_case.clamp_min, + /*clamp_max=*/test_case.clamp_max); for (int i = 0; i < m * n; i++) { EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); } } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - k_eq_gs_32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + k_eq_gs_32) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - large_k_n_gs32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + large_k_n_gs32) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - even_n_gs32) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + even_n_gs32) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, - false /*has_clamp*/>( + false /*has_bias*/, false /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - clamp_k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + clamp_k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, - true /*has_clamp*/>( + false /*has_bias*/, true /*has_clamp*/>( /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); } -TEST( - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - m_clamp_k_eq_gs128) { +TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + m_clamp_k_eq_gs128) { test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, - true /*has_clamp*/>( + false /*has_bias*/, true /*has_clamp*/>( /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); } #endif // TORCHAO_ENABLE_ARM_I8MM diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 4130d72e32..709386998e 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -4,23 +4,21 @@ // This source code is licensed under the license found in the // LICENSE file in the root directory of this source tree. +#include +#include +#include #include #include #include #include -#include -#include -#include namespace torchao::ops::linear_8bit_act_xbit_weight { PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( - const UKernelConfig& ukernel_config, - int n, - int target_panels_per_thread) { + const UKernelConfig &ukernel_config, int n, int target_panels_per_thread) { TORCHAO_CHECK(n >= 1, "n must be >= 1"); - TORCHAO_CHECK( - target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1"); + TORCHAO_CHECK(target_panels_per_thread >= 1, + "target_panels_per_thread must be >= 1"); PackWeightDataTilingParams tiling_params; int nr = ukernel_config.nr; @@ -39,19 +37,15 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( return tiling_params; } -void pack_weight_data_operator( - const UKernelConfig& ukernel_config, - const PackWeightDataTilingParams& tiling_params, - // Outputs - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { +void pack_weight_data_operator(const UKernelConfig &ukernel_config, + const PackWeightDataTilingParams &tiling_params, + // Outputs + void *weight_data, + // Inputs + int n, int k, int group_size, + const int8_t *weight_qvals, + const float *weight_scales, + const int8_t *weight_zeros, const float *bias) { TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); @@ -71,27 +65,21 @@ void pack_weight_data_operator( int bias_offset = n_idx; ukernel_config.prepare_weight_data_fn( - (char*)weight_data + weight_data_offset, - /*n=*/nc_tile_size, - k, - group_size, - weight_qvals + weight_qvals_offset, + (char *)weight_data + weight_data_offset, + /*n=*/nc_tile_size, k, group_size, weight_qvals + weight_qvals_offset, weight_scales + weight_scales_and_zeros_offset, - weight_zeros + weight_scales_and_zeros_offset, - bias + bias_offset); + weight_zeros + weight_scales_and_zeros_offset, bias + bias_offset); }); } // This default mimics XNNPACK behavior if target_tiles_per_thread = 5 -LinearTilingParams get_default_linear_tiling_params( - const UKernelConfig& ukernel_config, - int m, - int n, - int target_tiles_per_thread) { +LinearTilingParams +get_default_linear_tiling_params(const UKernelConfig &ukernel_config, int m, + int n, int target_tiles_per_thread) { TORCHAO_CHECK(m >= 1, "m must be >= 1"); TORCHAO_CHECK(n >= 1, "n must be >= 1"); - TORCHAO_CHECK( - target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1"); + TORCHAO_CHECK(target_tiles_per_thread >= 1, + "target_tiles_per_thread must be >= 1"); LinearTilingParams tiling_params; auto num_threads = torchao::get_num_threads(); @@ -122,41 +110,29 @@ namespace internal { inline size_t get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - int m, - int k, - int group_size) { + const UKernelConfig &ukernel_config, + const LinearTilingParams &tiling_params, int m, int k, int group_size) { return ukernel_config.activation_data_size_fn( tiling_params.mc_by_mr * ukernel_config.mr, k, group_size); } inline size_t get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - int m, - int k, - int group_size) { + const UKernelConfig &ukernel_config, + const LinearTilingParams &tiling_params, int m, int k, int group_size) { return ukernel_config.activation_data_size_fn(m, k, group_size); } inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - char* activation_data_buffer, + const UKernelConfig &ukernel_config, + const LinearTilingParams &tiling_params, char *activation_data_buffer, // Outputs - float* output, + float *output, // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, + int m, int n, int k, int group_size, const void *weight_data, + const float *activations, // Ignored if has_clamp = false - float clamp_min, - float clamp_max) { + float clamp_min, float clamp_max) { int nr = ukernel_config.nr; int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.mr); int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); @@ -169,12 +145,9 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( int m_idx = mc_tile_idx * mc; int mc_tile_size = std::min(mc, m - m_idx); int activations_offset = m_idx * k; - ukernel_config.prepare_activation_data_fn( - activation_data_buffer, - /*m=*/mc_tile_size, - k, - group_size, - activations + activations_offset); + ukernel_config.prepare_activation_data_fn(activation_data_buffer, + /*m=*/mc_tile_size, k, group_size, + activations + activations_offset); torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { int nc_tile_idx = idx; @@ -188,32 +161,21 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, - /*n=*/nc_tile_size, - k, - group_size, - /*weight_data=*/(char*)weight_data + weight_data_offset, - /*activation_data=*/activation_data_buffer, - clamp_min, - clamp_max); + /*n=*/nc_tile_size, k, group_size, + /*weight_data=*/(char *)weight_data + weight_data_offset, + /*activation_data=*/activation_data_buffer, clamp_min, clamp_max); }); } } inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - char* activation_data_buffer, + const UKernelConfig &ukernel_config, + const LinearTilingParams &tiling_params, char *activation_data_buffer, // Outputs - float* output, + float *output, // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - float clamp_min, - float clamp_max) { + int m, int n, int k, int group_size, const void *weight_data, + const float *activations, float clamp_min, float clamp_max) { int mr = ukernel_config.mr; int nr = ukernel_config.nr; int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.mr); @@ -235,10 +197,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( ukernel_config.prepare_activation_data_fn( activation_data_buffer + activation_data_offset, - /*m=*/mc_tile_size, - k, - group_size, - activations + activations_offset); + /*m=*/mc_tile_size, k, group_size, activations + activations_offset); }); torchao::parallel_1d(0, num_mc_panels * num_nc_panels, [&](int64_t idx) { @@ -258,91 +217,59 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, - /*n=*/nc_tile_size, - k, - group_size, - /*weight_data=*/(char*)weight_data + weight_data_offset, + /*n=*/nc_tile_size, k, group_size, + /*weight_data=*/(char *)weight_data + weight_data_offset, /*activation_data=*/activation_data_buffer + activation_data_offset, - clamp_min, - clamp_max); + clamp_min, clamp_max); }); } } // namespace internal -void linear_operator( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - char* activation_data_buffer, - // Outputs - float* output, - // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - // Ignored if has_clamp = false - float clamp_min, - float clamp_max) { +void linear_operator(const UKernelConfig &ukernel_config, + const LinearTilingParams &tiling_params, + LinearTileSchedulingPolicy scheduling_policy, + char *activation_data_buffer, + // Outputs + float *output, + // Inputs + int m, int n, int k, int group_size, + const void *weight_data, const float *activations, + // Ignored if has_clamp = false + float clamp_min, float clamp_max) { TORCHAO_CHECK(group_size % 16 == 0, "group_size must be a multiple of 16"); TORCHAO_CHECK(k % group_size == 0, "group_size must divide k"); switch (scheduling_policy) { - case LinearTileSchedulingPolicy::single_mc_parallel_nc: - internal::linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( - ukernel_config, - tiling_params, - activation_data_buffer, - output, - m, - n, - k, - group_size, - weight_data, - activations, - clamp_min, - clamp_max); - break; - case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: - internal:: - linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( - ukernel_config, - tiling_params, - activation_data_buffer, - output, - m, - n, - k, - group_size, - weight_data, - activations, - clamp_min, - clamp_max); - break; - default: - TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); + case LinearTileSchedulingPolicy::single_mc_parallel_nc: + internal::linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( + ukernel_config, tiling_params, activation_data_buffer, output, m, n, k, + group_size, weight_data, activations, clamp_min, clamp_max); + break; + case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: + internal::linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( + ukernel_config, tiling_params, activation_data_buffer, output, m, n, k, + group_size, weight_data, activations, clamp_min, clamp_max); + break; + default: + TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); } } -size_t get_activation_data_buffer_size( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - int m, - int k, - int group_size) { +size_t +get_activation_data_buffer_size(const UKernelConfig &ukernel_config, + const LinearTilingParams &tiling_params, + LinearTileSchedulingPolicy scheduling_policy, + int m, int k, int group_size) { switch (scheduling_policy) { - case LinearTileSchedulingPolicy::single_mc_parallel_nc: - return internal:: - get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( - ukernel_config, tiling_params, m, k, group_size); - case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: - return internal:: - get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( - ukernel_config, tiling_params, m, k, group_size); - default: - TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); + case LinearTileSchedulingPolicy::single_mc_parallel_nc: + return internal:: + get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( + ukernel_config, tiling_params, m, k, group_size); + case LinearTileSchedulingPolicy::parallel_mc_parallel_nc: + return internal:: + get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( + ukernel_config, tiling_params, m, k, group_size); + default: + TORCHAO_CHECK(false, "Unimplemented LinearTileSchedulingPolicy"); } } diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index bcf9446f1b..1dc69dee74 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -5,41 +5,29 @@ // LICENSE file in the root directory of this source tree. #pragma once -#include #include +#include #include namespace torchao::ops::linear_8bit_act_xbit_weight { struct UKernelConfig { using activation_data_size_fn_type = size_t (*)(int m, int k, int group_size); - using prepare_activation_data_fn_type = void (*)( - void* activation_data, - int m, - int k, - int group_size, - const float* activations); + using prepare_activation_data_fn_type = void (*)(void *activation_data, int m, + int k, int group_size, + const float *activations); using weight_data_size_fn_type = size_t (*)(int n, int k, int group_size); - using prepare_weight_data_fn_type = void (*)( - void* weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias); - using kernel_fn_type = void (*)( - float* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max); + using prepare_weight_data_fn_type = void (*)(void *weight_data, int n, int k, + int group_size, + const int8_t *weight_qvals, + const float *weight_scales, + const int8_t *weight_zeros, + const float *bias); + using kernel_fn_type = void (*)(float *output, int output_m_stride, int m, + int n, int k, int group_size, + const void *weight_data, + const void *activation_data, float clamp_min, + float clamp_max); activation_data_size_fn_type activation_data_size_fn{nullptr}; // preferred_activation_data_alignment is only a preferred alignment for @@ -69,37 +57,30 @@ struct PackWeightDataTilingParams { int nc_by_nr{1}; }; -PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( - const UKernelConfig& ukernel_config, - int n, - int target_panels_per_thread = 1); +PackWeightDataTilingParams +get_default_pack_weight_data_tiling_params(const UKernelConfig &ukernel_config, + int n, + int target_panels_per_thread = 1); -inline size_t get_packed_weight_data_size( - const UKernelConfig& ukernel_config, - int n, - int k, - int group_size) { +inline size_t get_packed_weight_data_size(const UKernelConfig &ukernel_config, + int n, int k, int group_size) { return ukernel_config.weight_data_size_fn(n, k, group_size); } inline size_t get_preferred_packed_weight_data_alignment( - const UKernelConfig& ukernel_config) { + const UKernelConfig &ukernel_config) { return ukernel_config.preferred_weight_data_alignment; } -void pack_weight_data_operator( - const UKernelConfig& ukernel_config, - const PackWeightDataTilingParams& tiling_params, - // Outputs - void* weight_data, - // Inputs - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias); +void pack_weight_data_operator(const UKernelConfig &ukernel_config, + const PackWeightDataTilingParams &tiling_params, + // Outputs + void *weight_data, + // Inputs + int n, int k, int group_size, + const int8_t *weight_qvals, + const float *weight_scales, + const int8_t *weight_zeros, const float *bias); // Linear functions struct LinearTilingParams { @@ -107,46 +88,36 @@ struct LinearTilingParams { int nc_by_nr{1}; }; -LinearTilingParams get_default_linear_tiling_params( - const UKernelConfig& ukernel_config, - int m, - int n, - int target_tiles_per_thread = 5); +LinearTilingParams +get_default_linear_tiling_params(const UKernelConfig &ukernel_config, int m, + int n, int target_tiles_per_thread = 5); enum class LinearTileSchedulingPolicy { single_mc_parallel_nc, parallel_mc_parallel_nc }; -size_t get_activation_data_buffer_size( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - int m, - int k, - int group_size); +size_t +get_activation_data_buffer_size(const UKernelConfig &ukernel_config, + const LinearTilingParams &tiling_params, + LinearTileSchedulingPolicy scheduling_policy, + int m, int k, int group_size); inline size_t get_preferred_activation_data_buffer_alignment( - const UKernelConfig& ukernel_config) { + const UKernelConfig &ukernel_config) { return ukernel_config.preferred_activation_data_alignment; } -void linear_operator( - const UKernelConfig& ukernel_config, - const LinearTilingParams& tiling_params, - LinearTileSchedulingPolicy scheduling_policy, - char* activation_data_buffer, - // Outputs - float* output, - // Inputs - int m, - int n, - int k, - int group_size, - const void* weight_data, - const float* activations, - float clamp_min, - float clamp_max); +void linear_operator(const UKernelConfig &ukernel_config, + const LinearTilingParams &tiling_params, + LinearTileSchedulingPolicy scheduling_policy, + char *activation_data_buffer, + // Outputs + float *output, + // Inputs + int m, int n, int k, int group_size, + const void *weight_data, const float *activations, + float clamp_min, float clamp_max); } // namespace // torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index 52c3bbae12..bc88c0b725 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -10,11 +10,11 @@ #include #endif // defined(__aarch64__) || defined(__ARM_NEON) +#include #include #include #include #include -#include #include namespace { @@ -27,45 +27,39 @@ get_ukernel_config(torchao::ops::PackedWeightsHeader header) { switch (header.format) { #if defined(__aarch64__) || defined(__ARM_NEON) - case torchao::ops::PackedWeightsFormat:: - linear_8bit_act_xbit_weight_universal: - namespace ukernel - = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - - // Check packing params match the kernel - TORCHAO_CHECK( - header == - torchao::ops::linear_8bit_act_xbit_weight:: - get_packed_weights_header_universal( - weight_nbit, - has_weight_zeros, - has_bias, - /*nr=*/8, - /*kr=*/16), - "Packing params do not match what kernel supports"); - - config.packed_weights_header = header; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel:: - prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - return config; - break; + case torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal: + namespace ukernel + = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + + // Check packing params match the kernel + TORCHAO_CHECK(header == torchao::ops::linear_8bit_act_xbit_weight:: + get_packed_weights_header_universal( + weight_nbit, has_weight_zeros, has_bias, + /*nr=*/8, + /*kr=*/16), + "Packing params do not match what kernel supports"); + + config.packed_weights_header = header; + config.mr = 1; + config.nr = 8; + config.activation_data_size_fn = + &ukernel::activation_data_size; + config.preferred_activation_data_alignment = 16; // size of neon register + config.prepare_activation_data_fn = + &ukernel::prepare_activation_data; + config.weight_data_size_fn = + &ukernel::weight_data_size; + config.preferred_weight_data_alignment = 16; // size of neon register + config.prepare_weight_data_fn = + &ukernel::prepare_weight_data; + config.kernel_fn = + &ukernel::kernel; + return config; + break; #endif // defined(__aarch64__) || defined(__ARM_NEON) - default: - TORCHAO_CHECK(false, "Unsupported packed weights format"); + default: + TORCHAO_CHECK(false, "Unsupported packed weights format"); } } @@ -73,24 +67,22 @@ template inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig get_ukernel_config() { auto header = torchao::ops::linear_8bit_act_xbit_weight:: - get_packed_weights_header_universal( - weight_nbit, has_weight_zeros, has_bias, /*nr=*/8, /*kr=*/16); + get_packed_weights_header_universal(weight_nbit, has_weight_zeros, + has_bias, /*nr=*/8, /*kr=*/16); return get_ukernel_config( header); } #ifdef USE_ATEN template -Tensor pack_weights_cpu( - const Tensor& weight_qvals, - const Tensor& weight_scales, - const std::optional& weight_zeros, - int64_t group_size) { +Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales, + const std::optional &weight_zeros, + int64_t group_size) { // TODO: add op support for bias static_assert(has_bias == false); - TORCHAO_CHECK( - weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8"); + TORCHAO_CHECK(weight_qvals.dtype() == torch::kInt8, + "weight_qvals must be int8"); TORCHAO_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D"); // In PyTorch, weights are nxk in row-major format (with activations being @@ -101,57 +93,45 @@ Tensor pack_weights_cpu( int n = weight_qvals.size(0); int k = weight_qvals.size(1); - TORCHAO_CHECK( - weight_scales.dtype() == torch::kFloat32, - "weight_scales must be float32"); + TORCHAO_CHECK(weight_scales.dtype() == torch::kFloat32, + "weight_scales must be float32"); TORCHAO_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D"); TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1"); - TORCHAO_CHECK( - weight_scales.size(0) == ((n * k) / group_size), - "expected 1 scale per group"); - - TORCHAO_CHECK( - has_weight_zeros == weight_zeros.has_value(), - "has_weight_zeros must match weight_zeros.has_value()"); - const int8_t* weight_zeros_ptr = nullptr; + TORCHAO_CHECK(weight_scales.size(0) == ((n * k) / group_size), + "expected 1 scale per group"); + + TORCHAO_CHECK(has_weight_zeros == weight_zeros.has_value(), + "has_weight_zeros must match weight_zeros.has_value()"); + const int8_t *weight_zeros_ptr = nullptr; if constexpr (has_weight_zeros) { - TORCHAO_CHECK( - weight_zeros.value().dtype() == torch::kInt8, - "weight_zeros must be int8"); + TORCHAO_CHECK(weight_zeros.value().dtype() == torch::kInt8, + "weight_zeros must be int8"); TORCHAO_CHECK(weight_zeros.value().dim() == 1, "weight_zeros must be 1D"); - TORCHAO_CHECK( - weight_zeros.value().size(0) == ((n * k) / group_size), - "expected 1 zero per group"); + TORCHAO_CHECK(weight_zeros.value().size(0) == ((n * k) / group_size), + "expected 1 zero per group"); weight_zeros_ptr = weight_zeros.value().const_data_ptr(); } using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto ukernel_config = get_ukernel_config< - weight_nbit, - has_weight_zeros, - has_bias, - false /*has_clamp*/>(); + auto ukernel_config = get_ukernel_config(); auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( ukernel_config, n, /*target_panels_per_thread=*/1); - auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + auto packed_weight_data_size = + torchao::ops::PackedWeightsHeader::size() + get_packed_weight_data_size(ukernel_config, n, k, group_size); Tensor packed_weights = torch::empty( {static_cast(packed_weight_data_size)}, torch::kInt8); ukernel_config.packed_weights_header.write( packed_weights.mutable_data_ptr()); pack_weight_data_operator( - ukernel_config, - pack_weight_tiling_params, + ukernel_config, pack_weight_tiling_params, packed_weights.mutable_data_ptr() + torchao::ops::PackedWeightsHeader::size(), - n, - k, - group_size, - weight_qvals.const_data_ptr(), - weight_scales.const_data_ptr(), - weight_zeros_ptr, + n, k, group_size, weight_qvals.const_data_ptr(), + weight_scales.const_data_ptr(), weight_zeros_ptr, /*bias*/ nullptr); return packed_weights; @@ -161,58 +141,51 @@ Tensor pack_weights_cpu( #ifdef USE_ATEN template Tensor pack_weights_without_zeros_cpu( - const Tensor& weight_qvals, - const Tensor& weight_scales, + const Tensor &weight_qvals, const Tensor &weight_scales, // TODO(T200095131): convert to int64_t when supported by AOTI // group_size is a tensor with size (0, group_size) - const Tensor& group_size_tensor) { + const Tensor &group_size_tensor) { int64_t group_size = group_size_tensor.size(1); - return pack_weights_cpu< - weight_nbit, - /*has_weight_zeros*/ false, - /*has_bias*/ false>( - weight_qvals, weight_scales, std::nullopt, group_size); + return pack_weights_cpu(weight_qvals, weight_scales, + std::nullopt, group_size); } #endif // USE_ATEN #ifdef USE_ATEN template Tensor pack_weights_with_zeros_cpu( - const Tensor& weight_qvals, - const Tensor& weight_scales, - const Tensor& weight_zeros, + const Tensor &weight_qvals, const Tensor &weight_scales, + const Tensor &weight_zeros, // TODO(T200095131): convert to int64_t when supported by AOTI // group_size is a meta tensor with size (group_size) - const Tensor& group_size_tensor) { + const Tensor &group_size_tensor) { int64_t group_size = group_size_tensor.size(1); - return pack_weights_cpu< - weight_nbit, - /*has_weight_zeros*/ true, - /*has_bias*/ false>( - weight_qvals, weight_scales, weight_zeros, group_size); + return pack_weights_cpu(weight_qvals, weight_scales, + weight_zeros, group_size); } #endif // USE_ATEN #ifdef USE_ATEN template -Tensor pack_weights_meta( - const Tensor& weight_qvals, - const Tensor& weight_scales, - const std::optional& weight_zeros, - int64_t group_size) { +Tensor pack_weights_meta(const Tensor &weight_qvals, + const Tensor &weight_scales, + const std::optional &weight_zeros, + int64_t group_size) { TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1"); int n = weight_qvals.size(0); int k = weight_qvals.size(1); using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto ukernel_config = get_ukernel_config< - weight_nbit, - has_weight_zeros, - has_bias, - false /*has_clamp*/>(); + auto ukernel_config = get_ukernel_config(); - auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + + auto packed_weight_data_size = + torchao::ops::PackedWeightsHeader::size() + get_packed_weight_data_size(ukernel_config, n, k, group_size); return torch::empty({static_cast(packed_weight_data_size)}) .to("meta"); @@ -222,50 +195,43 @@ Tensor pack_weights_meta( #ifdef USE_ATEN template Tensor pack_weights_without_zeros_meta( - const Tensor& weight_qvals, - const Tensor& weight_scales, + const Tensor &weight_qvals, const Tensor &weight_scales, // TODO(T200095131): convert to int64_t when supported by AOTI // group_size is a meta tensor with size (group_size) - const Tensor& group_size_tensor) { + const Tensor &group_size_tensor) { int64_t group_size = group_size_tensor.size(1); - return pack_weights_meta< - weight_nbit, - /*has_weight_zeros*/ false, - /*has_bias*/ false>( - weight_qvals, weight_scales, std::nullopt, group_size); + return pack_weights_meta(weight_qvals, weight_scales, + std::nullopt, group_size); } #endif // USE_ATEN #ifdef USE_ATEN template Tensor pack_weights_with_zeros_meta( - const Tensor& weight_qvals, - const Tensor& weight_scales, - const Tensor& weight_zeros, + const Tensor &weight_qvals, const Tensor &weight_scales, + const Tensor &weight_zeros, // TODO(T200095131): convert to int64_t when supported by AOTI // group_size is a meta tensor with size (group_size) - const Tensor& group_size_tensor) { + const Tensor &group_size_tensor) { int64_t group_size = group_size_tensor.size(1); - return pack_weights_meta< - weight_nbit, - /*has_weight_zeros*/ true, - /*has_bias*/ false>( - weight_qvals, weight_scales, weight_zeros, group_size); + return pack_weights_meta(weight_qvals, weight_scales, + weight_zeros, group_size); } #endif // USE_ATEN #if defined(USE_ATEN) || defined(USE_EXECUTORCH) template -Tensor linear_out_cpu( - const Tensor& activations, - const Tensor& packed_weights, - // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to - // int64_t when supported by AOTI Currently they are tensors with size - // equal to (0, the int they wrap) - const Tensor& group_size_tensor, - const Tensor& n_tensor, - const Tensor& k_tensor, - Tensor& out) { +Tensor +linear_out_cpu(const Tensor &activations, const Tensor &packed_weights, + // TODO(T200095131): convert n_tensor, k_tensor, + // group_size_tensor to int64_t when supported by AOTI Currently + // they are tensors with size equal to (0, the int they wrap) + const Tensor &group_size_tensor, const Tensor &n_tensor, + const Tensor &k_tensor, Tensor &out) { int n = n_tensor.size(1); int k = k_tensor.size(1); int group_size = group_size_tensor.size(1); @@ -274,15 +240,15 @@ Tensor linear_out_cpu( TORCHAO_CHECK(group_size >= 1, "group_size must be >= 1"); #ifdef USE_ATEN - TORCHAO_CHECK( - activations.dtype() == torch::kFloat32, "activations must be float32"); + TORCHAO_CHECK(activations.dtype() == torch::kFloat32, + "activations must be float32"); #endif // USE_ATEN TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D"); int m = activations.size(0); int k_ = activations.size(1); - TORCHAO_CHECK( - k == k_, "activation shape is incompatible with packed weights."); + TORCHAO_CHECK(k == k_, + "activation shape is incompatible with packed weights."); #ifdef USE_ATEN TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32"); @@ -302,55 +268,40 @@ Tensor linear_out_cpu( TORCHAO_CHECK(packed_weights.dim() == 1, "packed_weights must be 1D"); #ifdef USE_ATEN - TORCHAO_CHECK( - packed_weights.dtype() == torch::kInt8, "packed_weights must be int8"); + TORCHAO_CHECK(packed_weights.dtype() == torch::kInt8, + "packed_weights must be int8"); #endif // USE_ATEN - TORCHAO_CHECK( - packed_weights.size(0) >= torchao::ops::PackedWeightsHeader::size(), - "packed_weights is not big enough to read the header."); + TORCHAO_CHECK(packed_weights.size(0) >= + torchao::ops::PackedWeightsHeader::size(), + "packed_weights is not big enough to read the header."); auto header = torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); - auto ukernel_config = get_ukernel_config< - weight_nbit, - has_weight_zeros /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>(header); - - auto linear_tiling_params = get_default_linear_tiling_params( - ukernel_config, - m, - n, - /*target_tiles_per_thread=*/5); + auto ukernel_config = + get_ukernel_config(header); + + auto linear_tiling_params = + get_default_linear_tiling_params(ukernel_config, m, n, + /*target_tiles_per_thread=*/5); auto linear_scheduling_policy = LinearTileSchedulingPolicy::single_mc_parallel_nc; auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - m, - k, + ukernel_config, linear_tiling_params, linear_scheduling_policy, m, k, group_size); std::vector activation_data_buffer(activation_data_buffer_size); - linear_operator( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - activation_data_buffer.data(), - out.mutable_data_ptr(), - m, - n, - k, - group_size, - packed_weights.const_data_ptr() + - torchao::ops::PackedWeightsHeader::size(), - activations.const_data_ptr(), - // Clamp parameters are ignored because config is created from - // has_clamp = false - /*clamp_min=*/0.0, - /*clamp_max=*/0.0); + linear_operator(ukernel_config, linear_tiling_params, + linear_scheduling_policy, activation_data_buffer.data(), + out.mutable_data_ptr(), m, n, k, group_size, + packed_weights.const_data_ptr() + + torchao::ops::PackedWeightsHeader::size(), + activations.const_data_ptr(), + // Clamp parameters are ignored because config is created from + // has_clamp = false + /*clamp_min=*/0.0, + /*clamp_max=*/0.0); return out; } @@ -358,23 +309,17 @@ Tensor linear_out_cpu( #ifdef USE_ATEN template -Tensor linear_cpu( - const Tensor& activations, - const Tensor& packed_weights, - // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to - // int64_t when supported by AOTI Currently they are tensors with size - // equal to (0, the int they wrap) - const Tensor& group_size_tensor, - const Tensor& n_tensor, - const Tensor& k_tensor) { +Tensor +linear_cpu(const Tensor &activations, const Tensor &packed_weights, + // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to + // int64_t when supported by AOTI Currently they are tensors with + // size equal to (0, the int they wrap) + const Tensor &group_size_tensor, const Tensor &n_tensor, + const Tensor &k_tensor) { Tensor output_tensor = torch::empty({}, torch::kFloat32); - linear_out_cpu( - activations, - packed_weights, - group_size_tensor, - n_tensor, - k_tensor, - output_tensor); + linear_out_cpu(activations, packed_weights, + group_size_tensor, n_tensor, + k_tensor, output_tensor); return output_tensor; } #endif // USE_ATEN @@ -382,14 +327,12 @@ Tensor linear_cpu( #ifdef USE_ATEN template Tensor linear_meta( - const Tensor& activations, - const Tensor& packed_weights, + const Tensor &activations, const Tensor &packed_weights, // TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to // int64_t when supported by AOTI // Currently they are tensors with size equal to (0, the int they wrap) - const Tensor& group_size_tensor, - const Tensor& n_tensor, - const Tensor& k_tensor) { + const Tensor &group_size_tensor, const Tensor &n_tensor, + const Tensor &k_tensor) { int n = n_tensor.size(1); int k = k_tensor.size(1); TORCHAO_CHECK(n >= 1, "n must be >= 1"); @@ -398,8 +341,8 @@ Tensor linear_meta( TORCHAO_CHECK(activations.dim() == 2, "activations must be 2D"); int m = activations.size(0); int k_ = activations.size(1); - TORCHAO_CHECK( - k == k_, "activation shape is incompatible with packed weights."); + TORCHAO_CHECK(k == k_, + "activation shape is incompatible with packed weights."); return torch::empty({m, n}).to("meta"); } #endif // USE_ATEN diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index 932ecac4b2..bcf746e00e 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -15,10 +15,10 @@ #if defined(TORCHAO_ENABLE_KLEIDI) #include #include -#if defined (TORCHAO_ENABLE_ARM_I8MM) +#if defined(TORCHAO_ENABLE_ARM_I8MM) #include #include -#endif // TORCHAO_ENABLE_ARM_I8MM +#endif // TORCHAO_ENABLE_ARM_I8MM #endif // TORCHAO_ENABLE_KLEIDI const float kTol = 1.0e-5; @@ -49,27 +49,24 @@ UKernelConfig get_ukernel_config() { return config; } -template -void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size, const UKernelConfig* ukernel_config_arg = nullptr) { +template +void test_linear_8bit_act_xbit_weight( + int m, int n, int k, int group_size, + const UKernelConfig *ukernel_config_arg = nullptr) { UKernelConfig ukernel_config; if (ukernel_config_arg != nullptr) { ukernel_config = *ukernel_config_arg; } else { - ukernel_config = - get_ukernel_config(); + ukernel_config = get_ukernel_config(); } - auto test_case = torchao:: - channelwise_8bit_activation_groupwise_lowbit_weight_test_case::generate( - m, - k, - n, - group_size, - weight_nbit, - has_weight_zeros, - has_bias, - has_clamp, - /*round_weight_scales_to_bf16=*/has_kleidi); + auto test_case = + torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: + generate(m, k, n, group_size, weight_nbit, has_weight_zeros, has_bias, + has_clamp, + /*round_weight_scales_to_bf16=*/has_kleidi); auto output = std::vector(m * n); @@ -91,27 +88,17 @@ void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size, const preferred_packed_weight_data_alignment, packed_weight_data_size); pack_weight_data_operator( - ukernel_config, - pack_weight_data_tiling_params, - packed_weight_data.get(), - n, - k, - group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - test_case.weight_zeros.data(), - test_case.bias.data()); + ukernel_config, pack_weight_data_tiling_params, + packed_weight_data.get(), n, k, group_size, + test_case.weight_qvals.data(), test_case.weight_scales.data(), + test_case.weight_zeros.data(), test_case.bias.data()); // Allocate activation buffer auto linear_tiling_params = get_default_linear_tiling_params(ukernel_config, m, n); auto activation_data_buffer_size = get_activation_data_buffer_size( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - m, - k, + ukernel_config, linear_tiling_params, linear_scheduling_policy, m, k, group_size); auto activation_data_buffer_alignment = get_preferred_activation_data_buffer_alignment(ukernel_config); @@ -119,20 +106,11 @@ void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size, const activation_data_buffer_alignment, activation_data_buffer_size); // Run linear - linear_operator( - ukernel_config, - linear_tiling_params, - linear_scheduling_policy, - activation_data_buffer.get(), - output.data(), - m, - n, - k, - group_size, - packed_weight_data.get(), - test_case.activations.data(), - test_case.clamp_min, - test_case.clamp_max); + linear_operator(ukernel_config, linear_tiling_params, + linear_scheduling_policy, activation_data_buffer.get(), + output.data(), m, n, k, group_size, + packed_weight_data.get(), test_case.activations.data(), + test_case.clamp_min, test_case.clamp_max); // Test correctness for (int i = 0; i < m * n; i++) { @@ -145,90 +123,86 @@ void test_linear_8bit_act_xbit_weight(int m, int n, int k, int group_size, const #if defined(TORCHAO_ENABLE_KLEIDI) enum kai_kernel_id { - dotprod_1x4x32 = 0, - dotprod_1x8x32, - i8mm_4x8x32, - i8mm_8x4x32 + dotprod_1x4x32 = 0, + dotprod_1x8x32, + i8mm_4x8x32, + i8mm_8x4x32 }; -#define KAI_GEN_UKERNEL(kernel_ns) \ - namespace kernel = kernel_ns; \ - auto uk = kernel::get_ukernel(); \ - config.mr = uk.get_m_step(); \ - config.nr = uk.get_n_step(); \ - config.activation_data_size_fn = &kernel::activation_data_size; \ - config.weight_data_size_fn = &kernel::weight_data_size; \ - config.preferred_activation_data_alignment = kernel::get_preferred_alignement(); \ - config.preferred_weight_data_alignment = kernel::get_preferred_alignement(); \ - config.prepare_activation_data_fn = &kernel::prepare_activation_data; \ - config.prepare_weight_data_fn = &kernel::prepare_weight_data; \ - config.kernel_fn = &kernel::kernel; \ - -template -UKernelConfig get_ukernel_config_kleidi() { - UKernelConfig config; -#if defined (TORCHAO_ENABLE_ARM_I8MM) - if constexpr (kernel_id == i8mm_4x8x32) { - KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32); - return config; - } - if constexpr (kernel_id == i8mm_8x4x32) { - KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32); - return config; - } +#define KAI_GEN_UKERNEL(kernel_ns) \ + namespace kernel = kernel_ns; \ + auto uk = kernel::get_ukernel(); \ + config.mr = uk.get_m_step(); \ + config.nr = uk.get_n_step(); \ + config.activation_data_size_fn = &kernel::activation_data_size; \ + config.weight_data_size_fn = &kernel::weight_data_size; \ + config.preferred_activation_data_alignment = \ + kernel::get_preferred_alignement(); \ + config.preferred_weight_data_alignment = kernel::get_preferred_alignement(); \ + config.prepare_activation_data_fn = &kernel::prepare_activation_data; \ + config.prepare_weight_data_fn = &kernel::prepare_weight_data; \ + config.kernel_fn = &kernel::kernel; + +template UKernelConfig get_ukernel_config_kleidi() { + UKernelConfig config; +#if defined(TORCHAO_ENABLE_ARM_I8MM) + if constexpr (kernel_id == i8mm_4x8x32) { + KAI_GEN_UKERNEL( + torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32); + return config; + } + if constexpr (kernel_id == i8mm_8x4x32) { + KAI_GEN_UKERNEL( + torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32); + return config; + } #endif // TORCHAO_ENABLE_ARM_I8MM - if constexpr (kernel_id == dotprod_1x8x32) { - KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32); - return config; - } - KAI_GEN_UKERNEL(torchao::kernels::cpu::aarch64::kleidi::kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32); + if constexpr (kernel_id == dotprod_1x8x32) { + KAI_GEN_UKERNEL( + torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32); return config; + } + KAI_GEN_UKERNEL( + torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32); + return config; } #endif // TORCHAO_ENABLE_KLEIDI TEST(test_linear_8bit_act_xbit_weight, Standard) { - test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/>( + test_linear_8bit_act_xbit_weight<4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, false /*has_clamp*/>( /*m=*/13, /*n=*/8 * 10 + 3, /*k=*/16 * 3, /*group_size=*/16); } TEST(test_linear_8bit_act_xbit_weight, HasWeightZeros) { - test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - true /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( + test_linear_8bit_act_xbit_weight<4 /*weight_nbit*/, true /*has_weight_zeros*/, + true /*has_bias*/, false /*has_clamp*/>( /*m=*/13, /*n=*/8 * 10 + 3, /*k=*/16 * 3, /*group_size=*/16); } TEST(test_linear_8bit_act_xbit_weight, HasBias) { - test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/>( + test_linear_8bit_act_xbit_weight<4 /*weight_nbit*/, + false /*has_weight_zeros*/, + true /*has_bias*/, false /*has_clamp*/>( /*m=*/13, /*n=*/8 * 10 + 3, /*k=*/16 * 3, /*group_size=*/16); } TEST(test_linear_8bit_act_xbit_weight, HasClamp) { - test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/>( + test_linear_8bit_act_xbit_weight<4 /*weight_nbit*/, + false /*has_weight_zeros*/, + false /*has_bias*/, true /*has_clamp*/>( /*m=*/13, /*n=*/8 * 10 + 3, /*k=*/16 * 3, /*group_size=*/16); } TEST(test_linear_8bit_act_xbit_weight, SmallDimension) { - test_linear_8bit_act_xbit_weight< - 3 /*weight_nbit*/, - true /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/>( + test_linear_8bit_act_xbit_weight<3 /*weight_nbit*/, true /*has_weight_zeros*/, + true /*has_bias*/, true /*has_clamp*/>( /*m=*/1, /*n=*/1, /*k=*/16 * 3, /*group_size=*/16); } @@ -236,23 +210,17 @@ TEST(test_linear_8bit_act_xbit_weight, KNotDivisibleByGroupSize) { int n = 1; int k = 16 + 1; int group_size = 16; - auto ukernel_config = get_ukernel_config< - 3 /*weight_nbit*/, - true /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/>(); + auto ukernel_config = + get_ukernel_config<3 /*weight_nbit*/, true /*has_weight_zeros*/, + true /*has_bias*/, true /*has_clamp*/>(); auto pack_weight_data_tiling_params = get_default_pack_weight_data_tiling_params(ukernel_config, n); EXPECT_THROW( { pack_weight_data_operator( - ukernel_config, - pack_weight_data_tiling_params, - /*packed_weight_data=*/nullptr, - n, - k, - group_size, + ukernel_config, pack_weight_data_tiling_params, + /*packed_weight_data=*/nullptr, n, k, group_size, /*weight_qvals=*/nullptr, /*weight_scales=*/nullptr, /*weight_zeros=*/nullptr, @@ -266,23 +234,17 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) { int k = 20; int group_size = 10; - auto ukernel_config = get_ukernel_config< - 3 /*weight_nbit*/, - true /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/>(); + auto ukernel_config = + get_ukernel_config<3 /*weight_nbit*/, true /*has_weight_zeros*/, + true /*has_bias*/, true /*has_clamp*/>(); auto pack_weight_data_tiling_params = get_default_pack_weight_data_tiling_params(ukernel_config, n); EXPECT_THROW( { pack_weight_data_operator( - ukernel_config, - pack_weight_data_tiling_params, - /*packed_weight_data=*/nullptr, - n, - k, - group_size, + ukernel_config, pack_weight_data_tiling_params, + /*packed_weight_data=*/nullptr, n, k, group_size, /*weight_qvals=*/nullptr, /*weight_scales=*/nullptr, /*weight_zeros=*/nullptr, @@ -298,1395 +260,1072 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) { #if defined(TORCHAO_ENABLE_KLEIDI) /*****************/ -// dotprod_1x4x32 tests +// dotprod_1x4x32 tests /*****************/ - TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn6xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn4xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn4xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn6xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn22xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn26xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn26xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn102xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn102xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn222xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn22xk128xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn22xk128xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn26xk64xg64_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn26xk64xg64_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn34xk128xg64) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m2xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m2xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/2, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m3xn6xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m3xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m4xn8xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m4xn8xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/4, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m3xn6xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m3xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m31xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/31, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m32xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/32, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m33xn6xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m33xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/33, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m34xn8xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m34xn8xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/34, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m35xn6xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m35xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/35, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m7xn22xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/7, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m17xn26xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m17xn26xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/17, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m23xn102xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m23xn102xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/23, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m41xn222xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/19, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m23xn22xk128xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m23xn22xk128xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/23, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m29xn26xk64xg64_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m29xn26xk64xg64_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/29, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m101xn34xk128xg64) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m101xn34xk128xg64) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/101, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } - - - /*****************/ -// dotprod_1x8x32 tests +// dotprod_1x8x32 tests /*****************/ - TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn6xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn4xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn4xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn6xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn22xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn26xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn26xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn102xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn102xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn222xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn22xk128xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn22xk128xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn26xk64xg64_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn26xk64xg64_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn34xk128xg64) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m2xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m2xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/2, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m3xn6xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m3xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m4xn8xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m4xn8xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/4, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m3xn6xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m3xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m31xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/31, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m32xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/32, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m33xn6xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m33xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/33, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m34xn8xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m34xn8xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/34, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m35xn6xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m35xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/35, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m7xn22xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/7, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m17xn26xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m17xn26xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/17, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m23xn102xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m23xn102xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/23, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m41xn222xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/19, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m23xn22xk128xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m23xn22xk128xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/23, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m29xn26xk64xg64_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m29xn26xk64xg64_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/29, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m101xn34xk128xg64) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m101xn34xk128xg64) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/101, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } - - - /*****************/ -// i8mm_4x8x32 tests +// i8mm_4x8x32 tests /*****************/ #if defined(TORCHAO_ENABLE_ARM_I8MM) TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn4xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m1xn4xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn22xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn26xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn102xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m1xn102xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn222xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn22xk128xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m1xn22xk128xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn26xk64xg64_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m1xn26xk64xg64_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn34xk128xg64) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m2xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m2xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/2, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m3xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m4xn8xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m4xn8xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/4, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m3xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m31xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/31, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m32xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/32, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m33xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/33, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m34xn8xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m34xn8xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/34, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m35xn6xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m35xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/35, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m7xn22xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/7, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m17xn26xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m17xn26xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/17, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m23xn102xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m23xn102xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/23, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m41xn222xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/19, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m23xn22xk128xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m23xn22xk128xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/23, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m29xn26xk64xg64_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m29xn26xk64xg64_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/29, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m101xn34xk128xg64) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/101, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } #endif // TORCHAO_ENABLE_ARM_I8MM - /*****************/ -// i8mm_8x4x32 tests +// i8mm_8x4x32 tests /*****************/ #if defined(TORCHAO_ENABLE_ARM_I8MM) TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn4xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m1xn4xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn22xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn26xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn102xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m1xn102xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn222xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn22xk128xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m1xn22xk128xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn26xk64xg64_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m1xn26xk64xg64_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn34xk128xg64) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/1, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m2xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/2, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m2xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/2, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m3xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m4xn8xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m4xn8xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/4, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m3xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/3, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m31xn2xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/31, /*n=*/2, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m32xn4xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/32, /*n=*/4, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m33xn6xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/33, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m34xn8xk32xg32_bias_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m34xn8xk32xg32_bias_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/34, /*n=*/8, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m35xn6xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m35xn6xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/35, /*n=*/6, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m7xn22xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/7, /*n=*/22, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m17xn26xk32xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m17xn26xk32xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/17, /*n=*/26, /*k=*/32, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m23xn102xk32xg32_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m23xn102xk32xg32_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/23, /*n=*/102, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m41xn222xk32xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/19, /*n=*/14, /*k=*/64, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m23xn22xk128xg32_bias) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m23xn22xk128xg32_bias) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - true /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/23, /*n=*/22, /*k=*/128, /*group_size=*/32, &ukernel_config); } -TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m29xn26xk64xg64_clamp) { +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m29xn26xk64xg64_clamp) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - true /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( /*m=*/29, /*n=*/26, /*k=*/64, /*group_size=*/64, &ukernel_config); } TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m101xn34xk128xg64) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< - 4 /*weight_nbit*/, - false /*has_weight_zeros*/, - false /*has_bias*/, - false /*has_clamp*/, - true /*has_kleidi*/>( + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( /*m=*/101, /*n=*/34, /*k=*/128, /*group_size=*/64, &ukernel_config); } From aa9b9c90249763809c907d856d612b1662b8f9ae Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 17 Feb 2025 18:35:03 -0800 Subject: [PATCH 086/115] Fix `DDP` with `nf4` (#1684) * implement aten.cat.default for nf4 * add nf4 ddp tests * run ruff * add dtype check * formatting * run ruff format on nf4tensor --------- Co-authored-by: Mark Saroufim --- test/dtypes/ddp/check_ddp_nf4.py | 40 +++++++ test/dtypes/ddp/ddp_nf4.py | 155 ++++++++++++++++++++++++++++ test/dtypes/ddp/run_ddp_nf4_test.sh | 48 +++++++++ torchao/dtypes/nf4tensor.py | 30 ++++++ 4 files changed, 273 insertions(+) create mode 100644 test/dtypes/ddp/check_ddp_nf4.py create mode 100644 test/dtypes/ddp/ddp_nf4.py create mode 100755 test/dtypes/ddp/run_ddp_nf4_test.sh diff --git a/test/dtypes/ddp/check_ddp_nf4.py b/test/dtypes/ddp/check_ddp_nf4.py new file mode 100644 index 0000000000..608bcb9c02 --- /dev/null +++ b/test/dtypes/ddp/check_ddp_nf4.py @@ -0,0 +1,40 @@ +import argparse +from pathlib import Path + +import torch + +from torchao.dtypes.nf4tensor import NF4Tensor + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ref_checkpoint_dir", type=str, required=True) + parser.add_argument("--test_checkpoints_dir", type=str, required=True) + + args = parser.parse_args() + + ref_checkpoints = list(Path(args.ref_checkpoint_dir).glob("*.pt")) + assert len(ref_checkpoints) == 1, "Expected exactly one reference checkpoint" + ref_checkpoint = ref_checkpoints[0] + ref_state_dict = torch.load(ref_checkpoint, weights_only=True, map_location="cpu") + print(f"Ref checkpoint: {ref_checkpoint}") + + for path in Path(args.test_checkpoints_dir).glob("*.pt"): + print(f"Checking {path}") + state_dict = torch.load(path, weights_only=True, map_location="cpu") + assert ref_state_dict.keys() == state_dict.keys() + for name in ref_state_dict.keys(): + ref_param = ref_state_dict[name] + test_param = state_dict[name] + print(f"Checking {name} {type(ref_param)} {type(test_param)}") + + if isinstance(ref_param, NF4Tensor): + ref_param = ref_param.get_original_weight() + assert isinstance(test_param, NF4Tensor) + test_param = test_param.get_original_weight() + + if not torch.allclose(ref_param, test_param, atol=1e-4, rtol=1e-4): + diff = (ref_param - test_param).abs().max() + print(f" \u2718 Param {name} differs by {diff}") + else: + print(f" \u2713 Param {name} is consistent") + print("Passed!") diff --git a/test/dtypes/ddp/ddp_nf4.py b/test/dtypes/ddp/ddp_nf4.py new file mode 100644 index 0000000000..e38d0015b1 --- /dev/null +++ b/test/dtypes/ddp/ddp_nf4.py @@ -0,0 +1,155 @@ +import argparse +import math +import os +import time +from contextlib import contextmanager + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch._dynamo import config as dynamo_config +from torch.nn.parallel import DistributedDataParallel as DDP + +from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 + + +class LoRALinear(nn.Module): + def __init__( + self, + hidden_dim: int, + lora_rank: int = None, + lora_alpha: float = 16, + dtype: torch.dtype = torch.float32, + ): + super().__init__() + self.hidden_dim = hidden_dim + if lora_rank is None: + lora_rank = hidden_dim // 2 + + weight = torch.randn(hidden_dim, hidden_dim, dtype=dtype) + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha + self.register_parameter( + "weight", nn.Parameter(to_nf4(weight), requires_grad=False) + ) + self.lora_a = nn.Linear( + in_features=hidden_dim, out_features=self.lora_rank, bias=False + ) + self.lora_b = nn.Linear( + in_features=self.lora_rank, out_features=hidden_dim, bias=False + ) + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.lora_b.weight, a=math.sqrt(5)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = linear_nf4(input=x, weight=self.weight) + lora_out = self.lora_a(x) + lora_out = (self.lora_alpha / self.lora_rank) * self.lora_b(lora_out) + return out + lora_out + + +def _init_model(dim, num_linears, device, dtype) -> nn.Module: + with torch.device(device): + modules = [] + for i in range(num_linears): + modules += [LoRALinear(hidden_dim=dim, dtype=dtype)] + seq = nn.Sequential(*modules) + + return seq + + +def dist_print(*args, delay=0.5): + rank = dist.get_rank() + time.sleep(delay * rank) + print(f"[rank{rank}]: ", *args, flush=True) + + +def make_batch(global_bs, dim, dtype, device): + batch = torch.randn((global_bs, dim), dtype=dtype, device=device) + if dist.get_world_size() > 1: + batch = batch.chunk(dist.get_world_size(), dim=0)[dist.get_rank()] + return batch + + +def run_ddp(global_bs, dim, num_linears, device, dtype, num_steps, save_dir, compile): + os.makedirs(save_dir, exist_ok=True) + model = _init_model(dim, num_linears, device, dtype) + model = DDP(model, device_ids=[device]) + + if compile: + model = torch.compile(model) + optim = torch.optim.Adam(model.parameters(), lr=1e-2) + + losses = [] + + for i in range(num_steps): + inp = make_batch(global_bs, dim, dtype, device) + loss = model(inp).sum() + losses.append(loss) + loss.backward() + optim.step() + optim.zero_grad() + + dist.barrier() + + save_path = f"{save_dir}/ddp-{dist.get_rank()}.pt" + torch.save(model.state_dict(), save_path) + dist_print("Saved model to", save_path) + + +def init_dist(): + dist.init_process_group(backend="nccl") + torch.cuda.set_device(dist.get_rank()) + dist_print("Dist initialized with world size", dist.get_world_size()) + + +def cleanup_dist(): + dist.barrier() + if dist.get_rank() == 0: + print("Cleaning up dist") + dist.destroy_process_group() + + +@contextmanager +def distributed_context(): + init_dist() + yield + cleanup_dist() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--global_bs", type=int, default=8) + parser.add_argument("--dim", type=int, default=128) + parser.add_argument("--num_linears", type=int, default=1) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--num_steps", type=int, default=3) + parser.add_argument("--save_dir", type=str, default="checkpoints") + parser.add_argument("--compile", action="store_true") + parser.add_argument("--optimize_ddp", type=str, default="ddp_optimizer") + args = parser.parse_args() + + args.dtype = getattr(torch, args.dtype) + dynamo_config.optimize_ddp = args.optimize_ddp + + if args.optimize_ddp == "python_reducer": + dynamo_config.compiled_autograd = True + + with distributed_context(): + torch.manual_seed(args.seed) + run_ddp( + global_bs=args.global_bs, + dim=args.dim, + num_linears=args.num_linears, + device=args.device, + dtype=args.dtype, + num_steps=args.num_steps, + save_dir=args.save_dir, + compile=args.compile, + ) diff --git a/test/dtypes/ddp/run_ddp_nf4_test.sh b/test/dtypes/ddp/run_ddp_nf4_test.sh new file mode 100755 index 0000000000..b9a3c2929f --- /dev/null +++ b/test/dtypes/ddp/run_ddp_nf4_test.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +set -euo pipefail +WORLD_SIZE=${1:-2} + + +# Test params +GLOBAL_BS=8 +DIM=128 +NUM_LINEARS=1 +NUM_STEPS=3 + +PARAMS="--global_bs $GLOBAL_BS --dim $DIM --num_linears $NUM_LINEARS --num_steps $NUM_STEPS" +SAVE_DIR="checkpoints" +REF_DIR="${SAVE_DIR}/ref" +TEST_DIR="${SAVE_DIR}/test" +DDP_PROGRAM="ddp_nf4.py" +CHECK_PROGRAM="check_ddp_nf4.py" +REF_CMD="torchrun --nproc_per_node 1 $DDP_PROGRAM $PARAMS --save_dir $REF_DIR" +TEST_CMD="torchrun --nproc_per_node $WORLD_SIZE $DDP_PROGRAM $PARAMS --save_dir $TEST_DIR" +CHECK_CMD="python $CHECK_PROGRAM --ref_checkpoint_dir $REF_DIR --test_checkpoints_dir $TEST_DIR" +CLEANUP_CMD="rm -rf $SAVE_DIR" + +echo "Step 1: Generating reference checkpoint..." +echo $REF_CMD +$REF_CMD +echo -e "\n --- \n" +sleep 2 + +echo "Step 2: Generating test checkpoints..." +echo $TEST_CMD +$TEST_CMD +echo -e "\n --- \n" +sleep 2 + +# Check params +echo "Step 3: Checking params..." +echo $CHECK_CMD +$CHECK_CMD +echo -e "\n --- \n" +sleep 2 + +# Cleanup +echo "Step 4: Cleaning up..." +echo $CLEANUP_CMD +$CLEANUP_CMD +echo -e "\n --- \n" +echo "Done!" diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 5ae06a1fe1..457cf352fa 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -423,6 +423,35 @@ def nf4_pin_memory(aten_op, args, kwargs=None): return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) +@implements( + [ + aten.cat.default, + ] +) +def nf4_cat(aten_op: torch._ops.OpOverload, args, kwargs=None): + tensors_to_cat = args[0] + assert all(isinstance(t, torch.Tensor) for t in tensors_to_cat) + remaining_args = args[1:] + + ts = [] + for t in tensors_to_cat: + assert isinstance(t, torch.Tensor) + + if isinstance(t, NF4Tensor): + ts.append(t.get_original_weight()) + else: + ts.append(t) + + dtype = ts[0].dtype + assert all(t.dtype == dtype for t in ts) + + if kwargs is None: + kwargs = {} + + tensors = aten_op(ts, *remaining_args, **kwargs) + return tensors + + @dataclass(frozen=True) class SubclassTensorArgs: original_shape: torch.Size @@ -1058,3 +1087,4 @@ def nf4_constructor( if TORCH_VERSION_AT_LEAST_2_5: torch.serialization.add_safe_globals([NF4Tensor]) + torch.serialization.add_safe_globals([NF4Tensor]) From f2e8f5683a95b51feba3287a36d3c54d07b137be Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 18 Feb 2025 12:31:43 -0500 Subject: [PATCH 087/115] notify on wheel failure for aarch, m1, windows (#1725) * notify on build_wheels_windows.yml failure * notify on build_wheels_aarch64_linux.yml failure * Update build-wheels_m1.yml * testing change build_wheels_aarch64_linux.yml * Update build_wheels_aarch64_linux.yml * Update build-wheels_m1.yml * Update build_wheels_aarch64_linux.yml --- .github/workflows/build-wheels_m1.yml | 31 ++++++++++++++++ .../workflows/build_wheels_aarch64_linux.yml | 31 ++++++++++++++++ .github/workflows/build_wheels_windows.yml | 35 +++++++++++++++++++ 3 files changed, 97 insertions(+) diff --git a/.github/workflows/build-wheels_m1.yml b/.github/workflows/build-wheels_m1.yml index 93c8086a23..33a44191c5 100644 --- a/.github/workflows/build-wheels_m1.yml +++ b/.github/workflows/build-wheels_m1.yml @@ -41,3 +41,34 @@ jobs: runner-type: macos-m1-stable smoke-test-script: test/smoke_test.py trigger-event: ${{ github.event_name }} + notify: + runs-on: ubuntu-latest + name: Email notification + needs: [generate-matrix, build] + if: failure() && github.event_name == 'schedule' + steps: + - uses: dawidd6/action-send-mail@v4 + with: + server_address: smtp.gmail.com + server_port: 465 + username: torchao.notify + password: ${{ secrets.TORCHAO_NOTIFY_PASSWORD }} + from: torchao.notify@gmail.com + to: ${{ secrets.TORCHAO_NOTIFY_RECIPIENT }} + subject: Scheduled Build Failure for TorchAO + body: | + Build Failure Notification for TorchAO + A failure occurred in the Build Linux Wheels workflow. + Run Details: + - Workflow: ${{ github.workflow }} + - Run Type: ${{ github.event_name }} + - Repository: ${{ github.repository }} + - Branch/PR: ${{ github.ref }} + - Commit: ${{ github.sha }} + You can view the full run details here: + ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + Error Information: + ${{ needs.generate-matrix.result == 'failure' && 'Matrix generation failed' || '' }} + ${{ needs.build.result == 'failure' && 'Build job failed' || '' }} + + This is an automated notification. Please check the GitHub Actions page for more details about the failure. diff --git a/.github/workflows/build_wheels_aarch64_linux.yml b/.github/workflows/build_wheels_aarch64_linux.yml index 0f64aa53bf..9d54cda112 100644 --- a/.github/workflows/build_wheels_aarch64_linux.yml +++ b/.github/workflows/build_wheels_aarch64_linux.yml @@ -54,3 +54,34 @@ jobs: setup-miniconda: false secrets: PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + notify: + runs-on: ubuntu-latest + name: Email notification + needs: [generate-matrix, build] + if: failure() && github.event_name == 'schedule' + steps: + - uses: dawidd6/action-send-mail@v4 + with: + server_address: smtp.gmail.com + server_port: 465 + username: torchao.notify + password: ${{ secrets.TORCHAO_NOTIFY_PASSWORD }} + from: torchao.notify@gmail.com + to: ${{ secrets.TORCHAO_NOTIFY_RECIPIENT }} + subject: Scheduled Build Failure for TorchAO + body: | + Build Failure Notification for TorchAO + A failure occurred in the Build AARCH64 Wheels workflow. + Run Details: + - Workflow: ${{ github.workflow }} + - Run Type: ${{ github.event_name }} + - Repository: ${{ github.repository }} + - Branch/PR: ${{ github.ref }} + - Commit: ${{ github.sha }} + You can view the full run details here: + ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + Error Information: + ${{ needs.generate-matrix.result == 'failure' && 'Matrix generation failed' || '' }} + ${{ needs.build.result == 'failure' && 'Build job failed' || '' }} + + This is an automated notification. Please check the GitHub Actions page for more details about the failure. diff --git a/.github/workflows/build_wheels_windows.yml b/.github/workflows/build_wheels_windows.yml index bfb22cab3d..01db4b9d86 100644 --- a/.github/workflows/build_wheels_windows.yml +++ b/.github/workflows/build_wheels_windows.yml @@ -60,3 +60,38 @@ jobs: package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} + notify: + runs-on: ubuntu-latest + name: Email notification + needs: [generate-matrix, build] + if: failure() && github.event_name == 'schedule' + steps: + - uses: dawidd6/action-send-mail@v4 + with: + server_address: smtp.gmail.com + server_port: 465 + username: torchao.notify + password: ${{ secrets.TORCHAO_NOTIFY_PASSWORD }} + from: torchao.notify@gmail.com + to: ${{ secrets.TORCHAO_NOTIFY_RECIPIENT }} + subject: Scheduled Build Failure for TorchAO + body: | + Build Failure Notification for TorchAO + + A failure occurred in the Build Windows Wheels workflow. + + Run Details: + - Workflow: ${{ github.workflow }} + - Run Type: ${{ github.event_name }} + - Repository: ${{ github.repository }} + - Branch/PR: ${{ github.ref }} + - Commit: ${{ github.sha }} + + You can view the full run details here: + ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + + Error Information: + ${{ needs.generate-matrix.result == 'failure' && 'Matrix generation failed' || '' }} + ${{ needs.build.result == 'failure' && 'Build job failed' || '' }} + + This is an automated notification. Please check the GitHub Actions page for more details about the failure. From 7b37eb07c0996760697cba6578a4e9071dac1dd8 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 18 Feb 2025 10:37:50 -0800 Subject: [PATCH 088/115] Make TorchAO cpp/Python extension Differential Revision: D69634772 Pull Request resolved: https://github.com/pytorch/ao/pull/1719 --- test/dtypes/test_affine_quantized.py | 3 ++ test/quantization/test_marlin_qqq.py | 7 +-- test/test_ops.py | 7 +-- torchao/__init__.py | 54 ++++++++----------- .../rowwise_scaled_linear_cutlass_s4s4.cu | 10 ++-- .../rowwise_scaled_linear_cutlass_s8s4.cu | 12 +++-- 6 files changed, 42 insertions(+), 51 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 616701f1e3..112cab8684 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -23,6 +23,7 @@ from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, + is_fbcode, is_sm_at_least_89, ) @@ -213,6 +214,8 @@ class TestAffineQuantizedBasic(TestCase): @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) def test_flatten_unflatten(self, device, dtype): + if device == "cuda" and dtype == torch.bfloat16 and is_fbcode(): + raise unittest.SkipTest("TODO: Failing for cuda + bfloat16 in fbcode") apply_quant_list = get_quantization_functions(False, True, device) for apply_quant in apply_quant_list: linear = torch.nn.Linear(128, 256, dtype=dtype, device=device) diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index ebdf2281e0..1fd60acb52 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -1,5 +1,4 @@ import copy -import unittest import pytest import torch @@ -19,13 +18,9 @@ MappingType, choose_qparams_and_quantize_affine_qqq, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 -@unittest.skipIf( - is_fbcode(), - "Skipping the test in fbcode since we don't have TARGET file for kernels", -) class TestMarlinQQQ(TestCase): def setUp(self): super().setUp() diff --git a/test/test_ops.py b/test/test_ops.py index 54efefb026..b3b160e85f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -18,12 +18,7 @@ ) from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode - -if is_fbcode(): - pytest.skip( - "Skipping the test in fbcode since we don't have TARGET file for kernels" - ) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff try: import torchao.ops diff --git a/torchao/__init__.py b/torchao/__init__.py index 11716da62e..cc453e2d14 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -9,7 +9,6 @@ "ignore", message="Failed to initialize NumPy: No module named 'numpy'" ) - # We use this "hack" to set torchao.__version__ correctly # the version of ao is dependent on environment variables for multiple architectures # For local development this will default to whatever is version.txt @@ -21,34 +20,28 @@ except PackageNotFoundError: __version__ = "unknown" # In case this logic breaks don't break the build -_IS_FBCODE = ( - hasattr(torch._utils_internal, "IS_FBSOURCE") and torch._utils_internal.IS_FBSOURCE -) -if not _IS_FBCODE: - try: - from pathlib import Path - - so_files = list(Path(__file__).parent.glob("_C*.so")) - if len(so_files) > 0: - assert ( - len(so_files) == 1 - ), f"Expected one _C*.so file, found {len(so_files)}" - torch.ops.load_library(so_files[0]) - from . import ops - - # The following library contains CPU kernels from torchao/experimental - # They are built automatically by ao/setup.py if on an ARM machine. - # They can also be built outside of the torchao install process by - # running the script `torchao/experimental/build_torchao_ops.sh ` - # For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md - experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*")) - if len(experimental_lib) > 0: - assert ( - len(experimental_lib) == 1 - ), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}" - torch.ops.load_library(experimental_lib[0]) - except: - logging.debug("Skipping import of cpp extensions") +try: + from pathlib import Path + + so_files = list(Path(__file__).parent.glob("_C*.so")) + if len(so_files) > 0: + assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" + torch.ops.load_library(str(so_files[0])) + from . import ops + + # The following library contains CPU kernels from torchao/experimental + # They are built automatically by ao/setup.py if on an ARM machine. + # They can also be built outside of the torchao install process by + # running the script `torchao/experimental/build_torchao_ops.sh ` + # For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md + experimental_lib = list(Path(__file__).parent.glob("libtorchao_ops_aten.*")) + if len(experimental_lib) > 0: + assert ( + len(experimental_lib) == 1 + ), f"Expected at most one libtorchao_ops_aten.* file, found {len(experimental_lib)}" + torch.ops.load_library(str(experimental_lib[0])) +except: + logging.debug("Skipping import of cpp extensions") from torchao.quantization import ( autoquant, @@ -64,6 +57,3 @@ "testing", "ops", ] - -# test-pytorchbot -# test-codev diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu index e455b7bdf2..cc1b5ca123 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s4s4.cu @@ -14,10 +14,14 @@ rowwise_scaled_linear_cutlass_s4s4( " for xq and ", wq.dtype(), " for wq is not supported"); // Dispatch to appropriate kernel template. - using ElementA = cutlass::int4b_t; - using ElementB = cutlass::int4b_t; - return rowwise_scaled_linear_cutlass( + #if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) + // We get ElementA/ElementB types from the header + return rowwise_scaled_linear_cutlass( xq, x_scale, wq, w_scale, bias); + #else + TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s4s4 not available"); + return at::Tensor{}; + #endif } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { diff --git a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu index 680822ca7f..29f30d08fc 100644 --- a/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu +++ b/torchao/csrc/cuda/rowwise_scaled_linear_cutlass/rowwise_scaled_linear_cutlass_s8s4.cu @@ -1,5 +1,4 @@ #include - #include "rowwise_scaled_linear_cutlass.cuh" namespace torchao { @@ -13,11 +12,16 @@ rowwise_scaled_linear_cutlass_s8s4( __func__, " : The input datatypes combination ", xq.dtype(), " for xq and ", wq.dtype(), " for wq is not supported"); - // Dispatch to appropriate kernel template. +#if defined(BUILD_ROWWISE_SCALED_LINEAR_CUTLASS) + // Define ElementA as int8_t since it's a standard type using ElementA = int8_t; - using ElementB = cutlass::int4b_t; - return rowwise_scaled_linear_cutlass( + // ElementB comes from cutlass header + return rowwise_scaled_linear_cutlass( xq, x_scale, wq, w_scale, bias); +#else + TORCH_CHECK(false, "CUTLASS kernels not built - rowwise_scaled_linear_cutlass_s8s4 not available"); + return at::Tensor{}; +#endif } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { From 988c5c97800d1d8570b80d428cea9cf81e1c24c7 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 18 Feb 2025 13:13:28 -0800 Subject: [PATCH 089/115] fix tensor parallelism for float8 training with rowwise scaling (#1718) Summary: 1. add a test for toy model + TP + float8 rowwise scaling training 2. fix underlying issues to make the test pass: a. add fast path for tensor view where the new shape is the same as old shape, for rowwise scaled float8 (this is needed for DTensor) b. modify the fake grad dependency workaround to work when grad is a DTensor Test Plan: 1. ./test/float8/test_everything.sh (one transient failure: https://www.internalfb.com/phabricator/paste/view/P1733103301) 2. verified that float8 rowwise scaling behaves sanely in torchtitan on LLaMa 3 8B on 8 H100s, with tp 2: ``` // requires https://github.com/pytorch/torchtitan/pull/808 // baseline - bfloat16 + compile + tp 2 > with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.tensor_parallel_degree 2 --training.compile [rank0]:2025-02-14 13:41:16,175 - root - INFO - step: 40 loss: 7.4240 memory: 35.56GiB(37.43%) tps: 1,669 mfu: 9.77% // float8 baseline - float8 tensorwise + compile + tp 2 > with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile [rank0]:2025-02-14 13:44:07,806 - root - INFO - step: 40 loss: 7.4993 memory: 35.57GiB(37.44%) tps: 2,141 mfu: 12.54% // float8 rowwise without zero fake dep (for sanity) + compile + tp 2 > with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise [rank0]:2025-02-14 13:47:51,400 - root - INFO - step: 40 loss: 7.3472 memory: 35.55GiB(37.42%) tps: 1,858 mfu: 10.88% // float8 rowwise + compile + tp 2 > with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise [rank0]:2025-02-14 13:51:20,864 - root - INFO - step: 40 loss: 9.4211 memory: 35.55GiB(37.42%) tps: 1,820 mfu: 10.66% ``` Reviewers: Subscribers: Tasks: Tags: --- test/float8/test_dtensor.py | 91 +++++++++++++++++------- torchao/float8/float8_linear.py | 6 +- torchao/float8/float8_ops.py | 17 ++++- torchao/float8/float8_tensor_parallel.py | 39 ++++++---- 4 files changed, 113 insertions(+), 40 deletions(-) diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index 41b21e4406..d0f34da0a9 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -23,7 +23,12 @@ from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.distributed.tensor.parallel import parallelize_module +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + PrepareModuleInput, + RowwiseParallel, + parallelize_module, +) from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, Transformer, @@ -31,7 +36,13 @@ from tqdm import tqdm from torchao.float8 import Float8LinearConfig -from torchao.float8.config import CastConfig, ScalingType, e4m3_dtype +from torchao.float8.config import ( + CastConfig, + Float8LinearRecipeName, + ScalingType, + e4m3_dtype, + recipe_name_to_linear_config, +) from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic from torchao.float8.float8_tensor import ( @@ -49,6 +60,8 @@ from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor from torchao.testing.float8.dtensor_utils import ToyModel +torch.set_float32_matmul_precision("high") + def setup_distributed(): world_size = int(os.environ.get("WORLD_SIZE", -1)) @@ -180,13 +193,17 @@ def _test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): def _test_fp8_mlp_tensor_parallelism_base( - mesh: DeviceMesh, size=16, compile: bool = False + mesh: DeviceMesh, size=16, compile: bool = False, rowwise: bool = False ): device = mesh.device_type - # For now, only supports dynamic scaling of `x` and `dL_dY`. - # TODO(future): add support for float8 all-gather with delayed scaling - # for activations and gradients. - config = Float8LinearConfig(emulate=True) + + if rowwise: + config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE) + # hack around config being frozen + # TODO(future PR): we should make this nicer at the config level + object.__setattr__(config, "emulate", True) + else: + config = Float8LinearConfig(emulate=True) toy_model = ToyModel().to(device) toy_model_fp8 = convert_to_float8_training(toy_model, config=config) @@ -196,14 +213,28 @@ def _test_fp8_mlp_tensor_parallelism_base( sp_model = copy.deepcopy(toy_model) sp_model = convert_to_float8_training(sp_model, config=config) + # For tensorwise scaling, enable float8 all_gather. + # For rowwise scaling, keep high precision all_gather. Motivation for + # not doing float8 all-gather for rowwise: tensors need to be scaled both ways, + # so for float8 all-gather we'd need to send two float8 copies per tensor, + # which is similar # bytes over the wire than just doing bfloat16 all-gather. + if rowwise: + colwise_parallel_cls = ColwiseParallel + rowwise_parallel_cls = RowwiseParallel + prepare_input_cls = PrepareModuleInput + else: + colwise_parallel_cls = Float8ColwiseParallel + rowwise_parallel_cls = Float8RowwiseParallel + prepare_input_cls = PrepareFloat8ModuleInput + # vanilla TP tp_model = parallelize_module( tp_model, mesh, { - "ffn.w1": Float8ColwiseParallel(), - "ffn.w2": Float8ColwiseParallel(), - "ffn.out_proj": Float8RowwiseParallel(), + "ffn.w1": colwise_parallel_cls(), + "ffn.w2": colwise_parallel_cls(), + "ffn.out_proj": rowwise_parallel_cls(), }, ) @@ -212,33 +243,41 @@ def _test_fp8_mlp_tensor_parallelism_base( sp_model, mesh, { - "ffn": PrepareFloat8ModuleInput( + "ffn": prepare_input_cls( input_layouts=Shard(1), desired_input_layouts=Replicate() ), - "ffn.w1": Float8ColwiseParallel(), - "ffn.w2": Float8ColwiseParallel(), - "ffn.out_proj": Float8RowwiseParallel( + "ffn.w1": colwise_parallel_cls(), + "ffn.w2": colwise_parallel_cls(), + "ffn.out_proj": rowwise_parallel_cls( output_layouts=Shard(1), use_local_output=False ), }, ) - # PrepareFloat8ModuleInput with specific submodule fqn + # prepare_input_cls with specific submodule fqn sp_model2 = copy.deepcopy(toy_model) sp_model2 = convert_to_float8_training(sp_model2, config=config) + if rowwise: + prepare_input = prepare_input_cls( + input_layouts=Shard(1), + desired_input_layouts=Replicate(), + ) + else: + prepare_input = prepare_input_cls( + input_layouts=Shard(1), + desired_input_layouts=Replicate(), + fwd_config_submodule_fqn="w2", + ) + sp_model2 = parallelize_module( sp_model2, mesh, { - "ffn": PrepareFloat8ModuleInput( - input_layouts=Shard(1), - desired_input_layouts=Replicate(), - fwd_config_submodule_fqn="w2", - ), - "ffn.w1": Float8ColwiseParallel(), - "ffn.w2": Float8ColwiseParallel(), - "ffn.out_proj": Float8RowwiseParallel( + "ffn": prepare_input, + "ffn.w1": colwise_parallel_cls(), + "ffn.w2": colwise_parallel_cls(), + "ffn.out_proj": rowwise_parallel_cls( output_layouts=Shard(1), use_local_output=False ), }, @@ -278,11 +317,13 @@ def _test_fp8_mlp_tensor_parallelism_base( def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): - _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False) + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=False) + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False, rowwise=True) def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): - _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=False) + _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True, rowwise=True) def _test_distribute_fsdp_tensor_subclass(tp_mesh: DeviceMesh): diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index 0bc2690bc5..d822d33042 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -168,8 +168,10 @@ def backward(ctx, grad_output): ): # workaround from https://github.com/pytorch/pytorch/issues/141881 # to avoid saving float8 weight from forward to backward when - # FSDP is on - weight_hp_t = weight_hp_t + (grad_output_reshaped[0, 0] * 0) + # FSDP is on: add a fake dependency on `grad_output`. + g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0 + zero = g_reshaped[:1] * 0 + weight_hp_t = weight_hp_t + zero # Note: we need https://github.com/pytorch/pytorch/issues/136267 # to be solved to have a chance to reuse max(abs(weight, dim=...)) diff --git a/torchao/float8/float8_ops.py b/torchao/float8/float8_ops.py index 2af4160de4..36abd9dbc4 100644 --- a/torchao/float8/float8_ops.py +++ b/torchao/float8/float8_ops.py @@ -113,11 +113,25 @@ def float8_transpose(aten_op, args, kwargs=None): @implements([aten.view.default]) def float8_view(aten_op, args, kwargs=None): + t, new_shape = args[0], args[1] + + # if the new shape is the same as old, return an equivalent tensor + # note that we have to create a new wrapper to make PyTorch internals happy + if new_shape == list(t._data.shape): + new_data = aten_op(args[0]._data, *args[1:], **kwargs) + return Float8Tensor( + new_data, + args[0]._scale, + args[0]._orig_dtype, + args[0]._linear_mm_config, + args[0]._gemm_input_role, + args[0]._axiswise_dim, + ) + if len(args[0]._scale.shape) < 2: # tensorwise scaling return float8_desugar_op(aten_op, args, kwargs) - t, new_shape = args[0], args[1] # for now, only support reshaping to [-1, dim] or [dim, -1] axiswise_dim = t._axiswise_dim if len(new_shape) == 2: @@ -146,6 +160,7 @@ def float8_view(aten_op, args, kwargs=None): t._gemm_input_role, new_axiswise_dim, ) + raise AssertionError( f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} t._axiswise_dim {t._axiswise_dim} new_shape {new_shape} is not supported yet." ) diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index 9d45196cf3..a52b38b6bf 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -36,6 +36,11 @@ def _float8_linear_supports_float8_allgather(m): class Float8ColwiseParallel(ColwiseParallel): + """ + Like `ColwiseParallel`, but with all-gather in float8. This + currently assumes tensorwise scaling. + """ + @staticmethod def _prepare_input_fn( input_layouts, desired_input_layouts, mod, inputs, device_mesh @@ -96,6 +101,11 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: class Float8RowwiseParallel(RowwiseParallel): + """ + Like `RowwiseParallel`, but with all-gather in float8. This + currently assumes tensorwise scaling. + """ + @staticmethod def _prepare_input_fn( input_layouts, desired_input_layouts, mod, inputs, device_mesh @@ -154,18 +164,23 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: class PrepareFloat8ModuleInput(PrepareModuleInput): - # subclass the PrepareModuleInput classes to implement fp8 specific logic, the only difference is that - # after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor) - # This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) - # so that if there are multiple float8 users of the input activation, we perform fp8 allgather - # only once. - # FP8 Args: - # float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input, - # we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn - # fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used - # for the float8 cast. If not specified, we will search for the Float8Linear in the submodules - # and use the forward config from that module, in this case all module's forward config must be - # the same. + """ + Like `PrepareModuleInput`, but with all-gather in float8. This + currently assumes tensorwise scaling. + + The only difference from `PrepareModuleInput` is that + after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor) + This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate) + so that if there are multiple float8 users of the input activation, we perform fp8 allgather + only once. + FP8 Args: + float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input, + we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn + fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used + for the float8 cast. If not specified, we will search for the Float8Linear in the submodules + and use the forward config from that module, in this case all module's forward config must be + the same. + """ def __init__( self, From 79ac44ea22d91efcbe67778eaa5aca67103aa73f Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 18 Feb 2025 15:17:03 -0800 Subject: [PATCH 090/115] Promote Supermask out of prototype (#1729) This PR promotes Supermask and block sparsity from prototype -> `torchao.sparsity`, instead of the `apply_supermask` function which was previously closely coupled with SAM. It adds a new public API for `SupermaskLinear`, which users can use to add Supermask to their models for training with ``` sparsify_(model, lambda x: SupermaskLinear.from_linear(x, block_size=64, sparsity_level=0.9) ``` To accelerate for inference, we convert the `SupermaskLinear` model back into a `nn.Linear`, which simplifies the Supermask logic: ``` sparsify_(model, lambda x: SupermaskLinear.to_linear(x, sparsity_level=0.9) ``` **bc-breaking** The previous prototype APIs, `torchao.sparsity.prototype.superblock.supermask` and `torchao.prototype.sparsity.superblock.supermask` have been deprecated. You can use `torchao.sparsity.supermask` instead. --- test/sparsity/test_supermask.py | 61 +++ .../sparsity/superblock/supermask.py | 365 ------------------ torchao/sparsity/__init__.py | 2 + .../prototype/superblock/supermask.py | 6 +- torchao/sparsity/supermask.py | 148 +++++++ 5 files changed, 212 insertions(+), 370 deletions(-) create mode 100644 test/sparsity/test_supermask.py delete mode 100644 torchao/prototype/sparsity/superblock/supermask.py create mode 100644 torchao/sparsity/supermask.py diff --git a/test/sparsity/test_supermask.py b/test/sparsity/test_supermask.py new file mode 100644 index 0000000000..fa86850a07 --- /dev/null +++ b/test/sparsity/test_supermask.py @@ -0,0 +1,61 @@ +import logging +import unittest + +import pytest +import torch +from torch import nn +from torch.testing._internal import common_utils + +from torchao.sparsity import sparsify_ + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + + +class TestSupermask(common_utils.TestCase): + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @common_utils.parametrize("sparsity_level", [0.25, 0.5]) + @common_utils.parametrize("blocksize", [2, 4, 8]) + def test_supermask(self, sparsity_level, blocksize): + model = ( + nn.Sequential( + nn.Linear(16, 16, bias=False), + ) + .half() + .cuda() + .eval() + ) + + from torchao.sparsity import SupermaskLinear + + M, N = model[0].weight.shape + sparsify_( + model, + lambda x: SupermaskLinear.from_linear( + x, sparsity_level=sparsity_level, blocksize=blocksize + ), + ) + sparsify_(model, SupermaskLinear.to_linear) + weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize) + + # Test correct sparsity level + nnz = weight_bsr._nnz() + expected = round((M // blocksize) * (N // blocksize) * (1 - sparsity_level)) + assert nnz == expected, f"Expected {expected} nonzeros, got {nnz}" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + def test_from_linear(self): + from torchao.sparsity import SupermaskLinear + + linear = nn.Linear(128, 128) + supermask_linear = SupermaskLinear.from_linear( + linear, sparsity_level=0.5, blocksize=4 + ) + assert supermask_linear.weight.shape == linear.weight.shape + + +common_utils.instantiate_parametrized_tests(TestSupermask) + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/prototype/sparsity/superblock/supermask.py b/torchao/prototype/sparsity/superblock/supermask.py deleted file mode 100644 index abd23c566e..0000000000 --- a/torchao/prototype/sparsity/superblock/supermask.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -# original supermask -scores_min = None -scores_max = 9e9 -uniform_init_01 = False - -# adjusted supermask, initialize scores with uniform distribution in [0,1], clamp scores in each step in [0,1] -# scores_min=0. -# scores_max=1. -# uniform_init_01 = True - - -def percentile(t, q): - """Return the value that is larger than q% of t""" - k = 1 + round(0.01 * float(q) * (t.numel() - 1)) - return t.view(-1).kthvalue(k).values - - -class GetSubnet(torch.autograd.Function): - """Supermask STE function""" - - @staticmethod - def forward(ctx, scores, zeros, ones, sparsity): - clamped_scores = scores.clamp(min=scores_min, max=scores_max) - k_val = percentile(clamped_scores, sparsity * 100) - return torch.where( - clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device) - ) - - @staticmethod - def backward(ctx, g): - return g, None, None, None - - -class SupermaskLinear(nn.Linear): - """Supermask class for Linear layer""" - - def __init__( - self, - sparsity, - fixed_mask, - fixed_weight, - bitwidth, - transform, - fixed_transform, - *args, - **kwargs, - ): - tile_size = kwargs.pop("tile_size", 1) - super(SupermaskLinear, self).__init__(*args, **kwargs) - # initialize the scores - max_sparsity = 1 - ( - 1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()]) - ) - self.sparsity = sparsity - if self.sparsity > max_sparsity: - print( - f"reducing sparsity from {self.sparsity} to {max_sparsity}", - f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})", - ) - self.sparsity = max_sparsity - self.tile_size = tile_size - self.sparsify_weights = False - self.scores = nn.Parameter( - torch.empty( - [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] - ), - requires_grad=not fixed_mask, - ) - nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_( - self.scores, a=math.sqrt(5) - ) - - # the shift and the scale are transformation parameters - # the actually used weights = self.weight*self.scale+self.shift - # the transformation is activated only for quantized weights - self.shift = nn.Parameter(torch.Tensor(1).fill_(0.0), requires_grad=False) - self.scale = nn.Parameter(torch.Tensor(1).fill_(1.0), requires_grad=False) - - with torch.no_grad(): - # if bitwidth is None, then use floating point values in self.weight - # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) - # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 - # these quantized values are uniformly distributed - if bitwidth is not None: - weights_max = torch.max(self.weight).item() - weights_min = torch.min(self.weight).item() - least_step = (weights_max - weights_min) / pow(2, bitwidth) - left_bound = weights_min - 1e-6 - right_bound = weights_min + least_step + 1e-6 - # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1] ), requires_grad=not fixed_transform[1]) - # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; - self.shift = nn.Parameter( - torch.Tensor(1).fill_( - 0.0 if transform[0] is None else transform[0] - ), - requires_grad=not fixed_transform[0], - ) - self.scale = nn.Parameter( - torch.Tensor(1).fill_( - 1.0 if transform[1] is None else transform[1] - ), - requires_grad=not fixed_transform[1], - ) - for i in range(-int(pow(2, bitwidth - 1)), int(pow(2, bitwidth - 1))): - self.weight[ - torch.logical_and( - self.weight > left_bound, self.weight <= right_bound - ) - ] = i - left_bound = right_bound - right_bound += least_step - - self.weight.requires_grad = not fixed_weight - - def get_mask(self): - subnet = GetSubnet.apply( - self.scores, - torch.zeros_like(self.scores), - torch.ones_like(self.scores), - self.sparsity, - ) - - if self.tile_size != 1: - for i, k in enumerate(self.weight.shape): - subnet = subnet.repeat_interleave(self.tile_size, dim=i) - subnet = torch.narrow(subnet, i, 0, k) - - return subnet - - def sparsify_offline(self): - subnet = self.get_mask() - self.weight.data = (self.weight * self.scale + self.shift) * subnet - self.sparsify_weights = True - - def forward(self, x): - if not self.sparsify_weights: - subnet = self.get_mask() - w = (self.weight * self.scale + self.shift) * subnet - else: - w = self.weight - return F.linear(x, w, self.bias) - - -class SupermaskConv2d(nn.Conv2d): - """Supermask class for Conv2d layer""" - - def __init__( - self, - sparsity, - fixed_mask, - fixed_weight, - bitwidth, - transform, - fixed_transform, - *args, - **kwargs, - ): - tile_size = kwargs.pop("tile_size", 1) - super(SupermaskConv2d, self).__init__(*args, **kwargs) - # initialize the scores - max_sparsity = 1 - ( - 1 / math.prod([math.ceil(k / tile_size) for k in self.weight.size()]) - ) - self.sparsity = sparsity - if self.sparsity > max_sparsity: - print( - f"reducing sparsity from {self.sparsity} to {max_sparsity}", - f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {tile_size})", - ) - self.sparsity = max_sparsity - self.tile_size = tile_size - self.scores = nn.Parameter( - torch.empty( - [max(1, int(math.ceil(wn / tile_size))) for wn in self.weight.size()] - ), - requires_grad=not fixed_mask, - ) - nn.init.uniform_(self.scores) if uniform_init_01 else nn.init.kaiming_uniform_( - self.scores, a=math.sqrt(5) - ) - - # the shift and the scale are transformation parameters - # the actually used weights = self.weight*self.scale+self.shift - # the transformation is activated only for quantized weights - self.shift = nn.Parameter(torch.Tensor(1).fill_(0.0), requires_grad=False) - self.scale = nn.Parameter(torch.Tensor(1).fill_(1.0), requires_grad=False) - - with torch.no_grad(): - # if bitwidth is None, then use floating point values in self.weight - # if bitwidth is not None, then quantize self.weight into k-bit (k=bitwidth) - # quantized values are -2^(k-1), -2^(k-1)+1, ..., 0, 1, ..., 2^(k-1)-1 - # these quantized values are uniformly distributed - if bitwidth is not None: - weights_max = torch.max(self.weight).item() - weights_min = torch.min(self.weight).item() - least_step = (weights_max - weights_min) / pow(2, bitwidth) - left_bound = weights_min - 1e-6 - right_bound = weights_min + least_step + 1e-6 - # self.shift=nn.Parameter(torch.Tensor(1).fill_( (weights_min+(pow(2,bitwidth-1)+0.5)*least_step) if transform[0] is None else transform[0] ), requires_grad=not fixed_transform[0]) - # self.scale=nn.Parameter(torch.Tensor(1).fill_( least_step if transform[1] is None else transform[1]), requires_grad=not fixed_transform[1]) - # for example, if using binary weights (k=1) with -a, +a, set transform = [a,2a]; if using binary weights (k=1) with a, 0, set transform = [0,-a]; - self.shift = nn.Parameter( - torch.Tensor(1).fill_( - 0.0 if transform[0] is None else transform[0] - ), - requires_grad=not fixed_transform[0], - ) - self.scale = nn.Parameter( - torch.Tensor(1).fill_( - 1.0 if transform[1] is None else transform[1] - ), - requires_grad=not fixed_transform[1], - ) - for i in range(-int(pow(2, bitwidth - 1)), int(pow(2, bitwidth - 1))): - self.weight[ - torch.logical_and( - self.weight > left_bound, self.weight <= right_bound - ) - ] = i - left_bound = right_bound - right_bound += least_step - - self.weight.requires_grad = not fixed_weight - - def forward(self, x): - subnet = GetSubnet.apply( - self.scores, - torch.zeros_like(self.scores), - torch.ones_like(self.scores), - self.sparsity, - ) - - if self.tile_size != 1: - for i, k in enumerate(self.weight.shape): - # if k == 1: continue - subnet = subnet.repeat_interleave(self.tile_size, dim=i) - subnet = torch.narrow(subnet, i, 0, k) - - w = (self.weight * self.scale + self.shift) * subnet - return F.conv2d( - x, w, self.bias, self.stride, self.padding, self.dilation, self.groups - ) - - -def apply_supermask( - model, - linear_sparsity=0.0, - linear_sp_tilesize=1, - conv1x1_sparsity=0.0, - conv1x1_sp_tilesize=1, - conv_sparsity=0.0, - conv_sp_tilesize=1, - skip_last_layer_sparsity=False, - skip_first_transformer_sparsity=False, - device="cuda", - verbose=False, -): - sparsified_modules = {} - - for n, m in model.named_modules(): - # check conditions for skipping sparsity - if skip_last_layer_sparsity and n == "heads.head": - continue - if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n: - continue - - # convert 1x1 convolutions - if ( - conv1x1_sparsity != 0.0 - and isinstance(m, torch.nn.Conv2d) - and m.kernel_size == (1, 1) - ): - new_m = SupermaskConv2d( - conv1x1_sparsity, - False, - False, - None, - None, - None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv1x1_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - # convert all other convolutions (not tested!) - if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d): - new_m = SupermaskConv2d( - conv_sparsity, - False, - False, - None, - None, - None, - m.in_channels, - m.out_channels, - m.kernel_size, - stride=m.stride, - padding=m.padding, - dilation=m.dilation, - groups=m.groups, - bias=m.bias is not None, - padding_mode=m.padding_mode, - device=device, - tile_size=conv_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear): - new_m = SupermaskLinear( - linear_sparsity, - False, - False, - None, - None, - None, - m.in_features, - m.out_features, - bias=m.bias is not None, - device=device, - tile_size=linear_sp_tilesize, - ) - new_m.weight.data.copy_(m.weight.data) - if m.bias is not None: - new_m.bias.data.copy_(m.bias.data) - sparsified_modules[n] = new_m - continue - - # add modules to model - for k, v in sparsified_modules.items(): - sm_name, ch_name = k.rsplit(".", 1) - sm = model.get_submodule(sm_name) - sm.add_module(ch_name, v) - - if verbose: - print( - f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}' - ) - - return model diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index 77ccd2c00b..c13bb4209c 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -13,11 +13,13 @@ semi_sparse_weight, sparsify_, ) +from .supermask import SupermaskLinear from .utils import PerChannelNormObserver # noqa: F403 from .wanda import WandaSparsifier # noqa: F403 __all__ = [ "WandaSparsifier", + "SupermaskLinear", "PerChannelNormObserver", "apply_fake_sparsity", "sparsify_", diff --git a/torchao/sparsity/prototype/superblock/supermask.py b/torchao/sparsity/prototype/superblock/supermask.py index f502d1f2ad..97d0b36c79 100644 --- a/torchao/sparsity/prototype/superblock/supermask.py +++ b/torchao/sparsity/prototype/superblock/supermask.py @@ -1,11 +1,7 @@ -from torchao.prototype.sparsity.superblock.supermask import ( - GetSubnet, - SupermaskConv2d, +from torchao.sparsity.supermask import ( SupermaskLinear, ) __all__ = [ - "GetSubnet", - "SupermaskConv2d", "SupermaskLinear", ] diff --git a/torchao/sparsity/supermask.py b/torchao/sparsity/supermask.py new file mode 100644 index 0000000000..a04b824428 --- /dev/null +++ b/torchao/sparsity/supermask.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +SCORES_MIN = None +SCORES_MAX = 9e9 + + +def percentile(t, q): + """Return the value that is larger than q% of t""" + k = 1 + round(0.01 * float(q) * (t.numel() - 1)) + return t.view(-1).kthvalue(k).values + + +class GetSubnet(torch.autograd.Function): + """Supermask STE function""" + + @staticmethod + def forward(ctx, scores, zeros, ones, sparsity): + clamped_scores = scores.clamp(min=SCORES_MIN, max=SCORES_MAX) + k_val = percentile(clamped_scores, sparsity * 100) + return torch.where( + clamped_scores < k_val, zeros.to(scores.device), ones.to(scores.device) + ) + + @staticmethod + def backward(ctx, g): + return g, None, None, None + + +class ApplyMask(torch.autograd.Function): + """Supermask STE function""" + + @staticmethod + def forward(ctx, weight, scores): + return weight * scores + + @staticmethod + def backward(ctx, grad_output): + grad_weight = grad_scores = None + if ctx.needs_input_grad[0]: + grad_weight = grad_output + if ctx.needs_input_grad[1]: + grad_scores = grad_output + return grad_weight, grad_scores + + +class SupermaskLinear(nn.Linear): + """Supermask class for Linear layer""" + + def __init__( + self, sparsity_level, blocksize, fixed_mask, fixed_weight, *args, **kwargs + ): + super(SupermaskLinear, self).__init__(*args, **kwargs) + # calculate the maximum sparsity given blocksize for the layer + max_sparsity_level = 1 - ( + 1 / math.prod([math.ceil(k / blocksize) for k in self.weight.size()]) + ) + self.sparsity_level = sparsity_level + if self.sparsity_level > max_sparsity_level: + print( + f"reducing sparsity from {self.sparsity} to {max_sparsity_level}", + f"(maximum sparsity for layer with shape {self.weight.size()} and tile size {blocksize})", + ) + self.sparsity_level = max_sparsity_level + self.blocksize = blocksize + self.sparsify_weights = False + self.scores = nn.Parameter( + torch.empty( + [max(1, int(math.ceil(wn / blocksize))) for wn in self.weight.size()] + ), + requires_grad=not fixed_mask, + ) + nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5)) + + # NOTE: the previous implementation of Supermask supported quantizing the weights, this has been removed. + + self.weight.requires_grad = not fixed_weight + + def get_mask(self): + subnet = GetSubnet.apply( + self.scores, + torch.zeros_like(self.scores), + torch.ones_like(self.scores), + self.sparsity_level, + ) + + if self.blocksize != 1: + for i, k in enumerate(self.weight.shape): + subnet = subnet.repeat_interleave(self.blocksize, dim=i) + subnet = torch.narrow(subnet, i, 0, k) + + return subnet + + def forward(self, x): + subnet = self.get_mask() + w = ApplyMask.apply(self.weight, subnet) + return F.linear(x, w, self.bias) + + @classmethod + def from_linear( + cls, + linear, + sparsity_level=0.0, + blocksize=1, + ): + """ + Main entrypoint for creating a SupermaskLinear from a Linear layer. + """ + assert isinstance(linear, torch.nn.Linear) + + supermask_linear = SupermaskLinear( + sparsity_level, + blocksize, + False, + False, + linear.in_features, + linear.out_features, + bias=linear.bias is not None, + ).to(device=linear.weight.device, dtype=linear.weight.dtype) + supermask_linear.weight.data.copy_(linear.weight.data) + if linear.bias is not None: + supermask_linear.bias.data.copy_(linear.bias.data) + return supermask_linear + + @classmethod + def to_linear(cls, supermask_linear): + """ + Convert a SupermaskLinear to a Linear layer. + Replaces the old sparsify_offline() function. + """ + self = supermask_linear + + linear = torch.nn.Linear( + self.in_features, + self.out_features, + bias=self.bias is not None, + ).to(device=self.weight.device, dtype=self.weight.dtype) + + mask = self.get_mask() + linear.weight.data.copy_(self.weight * mask) + if self.bias is not None: + linear.bias.data.copy_(self.bias.data) + return linear From c59561a769d205d65444b4340b5b8d13697b3c53 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Tue, 18 Feb 2025 19:56:18 -0800 Subject: [PATCH 091/115] SAM2: Update README.md (#1735) Update README.md --- examples/sam2_amg_server/README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/examples/sam2_amg_server/README.md b/examples/sam2_amg_server/README.md index c09b012c26..2a35ad9fe1 100644 --- a/examples/sam2_amg_server/README.md +++ b/examples/sam2_amg_server/README.md @@ -1,3 +1,29 @@ +# Reproducing experiments locally + +You can simply run `python reproduce_experiments.py ` + +`image_paths_file` needs to be a flat list of paths to images, for example + +``` +/home/$USER/data/sav_val/JPEGImages_24fps/sav_044979/00349.jpg +/home/$USER/data/sav_val/JPEGImages_24fps/sav_006751/00204.jpg +/home/$USER/data/sav_val/JPEGImages_24fps/sav_053118/00239.jpg +/home/$USER/data/sav_val/JPEGImages_24fps/sav_053391/00517.jpg +/home/$USER/data/sav_val/JPEGImages_24fps/sav_018487/00001.jpg +/home/$USER/data/sav_val/JPEGImages_24fps/sav_028552/00153.jpg +/home/$USER/data/sav_val/JPEGImages_24fps/sav_013729/00103.jpg +/home/$USER/data/sav_val/JPEGImages_24fps/sav_014662/00339.jpg +``` + +or whichever other files you'd like to use for study. For example you may consider the Segment Anything Video (SA-V) [Dataset](https://github.com/facebookresearch/sam2/tree/main/sav_dataset#download-the-dataset). + +The experimental results will then be saved under `output_folder` in result.csv + +# Reproducing experiments on Modal + +For this you can run `modal_experiments.sh` after, but you'll want to experiments locally first to produce the meta annotations and exported ahead-of-time compiled binaries. + +# Using the server locally ## Example curl command ``` curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output path/to/output.png From 7fc8ad40df487b39010e357cd3e75f4a300239e8 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 18 Feb 2025 19:58:08 -0800 Subject: [PATCH 092/115] float8 training: clean up recipe names (#1730) Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 4 ++-- test/float8/test_base.py | 4 ++-- test/float8/test_compile.py | 4 ++-- test/float8/test_dtensor.py | 2 +- test/float8/test_numerics_integration.py | 4 ++-- torchao/float8/config.py | 12 ++++++------ 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 2b3f631d8c..9bd4206d76 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -349,7 +349,7 @@ def run( # get the float8 dynamic axiswise scaling gpu kernel time torch._dynamo.reset() - config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE) + config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE) m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config) m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs) fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x) @@ -358,7 +358,7 @@ def run( # TODO(future PR): enable below once basic performance issues # are fixed # torch._dynamo.reset() - # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP) + # config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP) # m_fp8_lw = convert_to_float8_training(m_orig, config=config) # m_fp8_lw = torch.compile(m_fp8_lw) # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index b537c7ab9f..055b3f3054 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -420,8 +420,8 @@ def test_linear_from_config_params( @pytest.mark.parametrize( "recipe_name", [ - Float8LinearRecipeName.ALL_AXISWISE, - Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.ROWWISE, + Float8LinearRecipeName.ROWWISE_WITH_GW_HP, ], ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index d9c71f7395..83ec188192 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -218,8 +218,8 @@ def test_inductor_from_config_params( @pytest.mark.parametrize( "recipe_name", [ - Float8LinearRecipeName.ALL_AXISWISE, - Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.ROWWISE, + Float8LinearRecipeName.ROWWISE_WITH_GW_HP, ], ) @unittest.skipIf( diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index d0f34da0a9..d71e23b6b2 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -198,7 +198,7 @@ def _test_fp8_mlp_tensor_parallelism_base( device = mesh.device_type if rowwise: - config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE) + config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE) # hack around config being frozen # TODO(future PR): we should make this nicer at the config level object.__setattr__(config, "emulate", True) diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 311964d831..e47d4310b4 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -198,8 +198,8 @@ def test_encoder_fw_bw_from_config_params( @pytest.mark.parametrize( "recipe_name", [ - Float8LinearRecipeName.ALL_AXISWISE, - Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP, + Float8LinearRecipeName.ROWWISE, + Float8LinearRecipeName.ROWWISE_WITH_GW_HP, ], ) @pytest.mark.skipif( diff --git a/torchao/float8/config.py b/torchao/float8/config.py index b971ff31b0..c1720ea70c 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -326,9 +326,9 @@ def __post_init__(self): # TODO(future PR): go through a round of design on this, and eventually expose # as a top level public API. class Float8LinearRecipeName(enum.Enum): - ALL_TENSORWISE = "all_tensorwise" - ALL_AXISWISE = "all_axiswise" - LW_AXISWISE_WITH_GW_HP = "lw_axiswise_with_gw_hp" + TENSORWISE = "tensorwise" + ROWWISE = "rowwise" + ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp" def recipe_name_to_linear_config( @@ -339,11 +339,11 @@ def recipe_name_to_linear_config( Output: a `Float8LinearConfig` configured to implement the recipe """ - if recipe_name is Float8LinearRecipeName.ALL_TENSORWISE: + if recipe_name is Float8LinearRecipeName.TENSORWISE: # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel return Float8LinearConfig() - elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE: + elif recipe_name is Float8LinearRecipeName.ROWWISE: # dynamic axiswise scaling with the CUTLASS rowwise kernel cc_i = CastConfig( scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype @@ -363,7 +363,7 @@ def recipe_name_to_linear_config( round_scales_to_power_of_2=True, ) - elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP: + elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP: # lw's recipe for a modification on all-axiswise: # # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 From c6c388b53dd1734bb2ce96b16f2680a3cb68feaa Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 18 Feb 2025 19:59:01 -0800 Subject: [PATCH 093/115] float8 training: make the "config from recipe" API polished (#1731) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 5 +- benchmarks/float8/profile_linear_float8.py | 6 +- test/float8/test_base.py | 3 +- test/float8/test_compile.py | 3 +- test/float8/test_dtensor.py | 3 +- test/float8/test_numerics_integration.py | 3 +- torchao/float8/config.py | 169 +++++++++++---------- 7 files changed, 97 insertions(+), 95 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 9bd4206d76..684ed0af2a 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -63,7 +63,6 @@ ScalingType, convert_to_float8_training, ) -from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config from torchao.float8.roofline_utils import ( get_float8_mem_sympy, get_gemm_time_sympy, @@ -349,7 +348,7 @@ def run( # get the float8 dynamic axiswise scaling gpu kernel time torch._dynamo.reset() - config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE) + config = Float8LinearConfig.from_recipe_name("rowwise") m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config) m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs) fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x) @@ -358,7 +357,7 @@ def run( # TODO(future PR): enable below once basic performance issues # are fixed # torch._dynamo.reset() - # config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE_WITH_GW_HP) + # config = Float8LinearConfig.from_recipe_name("rowwise_with_gw_hp") # m_fp8_lw = convert_to_float8_training(m_orig, config=config) # m_fp8_lw = torch.compile(m_fp8_lw) # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 5045956954..687684d4e2 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -39,9 +39,8 @@ from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( - Float8LinearRecipeName, + Float8LinearConfig, ScalingType, - recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -311,8 +310,7 @@ def main( emulate=False, ) elif recipe_name is not None: - recipe_name = Float8LinearRecipeName(recipe_name) - config = recipe_name_to_linear_config(recipe_name) + config = Float8LinearConfig.from_recipe_name(recipe_name) scaling_repr = "_".join( [ diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 055b3f3054..156c8abe87 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -32,7 +32,6 @@ ScalingType, e4m3_dtype, e5m2_dtype, - recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -442,7 +441,7 @@ def test_linear_from_recipe( linear_dtype = torch.bfloat16 x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype) m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype) - config = recipe_name_to_linear_config(recipe_name) + config = Float8LinearConfig.from_recipe_name(recipe_name) self._test_linear_impl( x, m_ref, diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 83ec188192..0c02db26a6 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -33,7 +33,6 @@ Float8LinearRecipeName, ScalingType, e4m3_dtype, - recipe_name_to_linear_config, ) from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( @@ -227,7 +226,7 @@ def test_inductor_from_config_params( ) def test_inductor_from_recipe(recipe_name): torch._dynamo.reset() - config = recipe_name_to_linear_config(recipe_name) + config = Float8LinearConfig.from_recipe_name(recipe_name) fullgraph = True dtype = torch.bfloat16 _test_compile_base( diff --git a/test/float8/test_dtensor.py b/test/float8/test_dtensor.py index d71e23b6b2..886cc2a504 100644 --- a/test/float8/test_dtensor.py +++ b/test/float8/test_dtensor.py @@ -41,7 +41,6 @@ Float8LinearRecipeName, ScalingType, e4m3_dtype, - recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import convert_to_float8_training from torchao.float8.float8_scaling_utils import NoopFwToFloat8BwDynamic @@ -198,7 +197,7 @@ def _test_fp8_mlp_tensor_parallelism_base( device = mesh.device_type if rowwise: - config = recipe_name_to_linear_config(Float8LinearRecipeName.ROWWISE) + config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE) # hack around config being frozen # TODO(future PR): we should make this nicer at the config level object.__setattr__(config, "emulate", True) diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index e47d4310b4..01e4cbb20d 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -28,7 +28,6 @@ Float8LinearConfig, Float8LinearRecipeName, ScalingType, - recipe_name_to_linear_config, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, @@ -210,7 +209,7 @@ def test_encoder_fw_bw_from_recipe( self, recipe_name: str, ): - config = recipe_name_to_linear_config(recipe_name) + config = Float8LinearConfig.from_recipe_name(recipe_name) self._test_impl(config) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index c1720ea70c..ab2d89a91f 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -7,7 +7,7 @@ import enum import logging from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import torch @@ -146,6 +146,32 @@ class Float8GemmConfig: use_fast_accum: bool = False +# Pre-made recipes for common configurations +class Float8LinearRecipeName(enum.Enum): + + # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel + TENSORWISE = "tensorwise" + + # dynamic rowwise scaling with the CUTLASS rowwise kernel + # * e4m3 for activations, weights, gradients + # * scales rounded (floor) to the nearest power of two for increased accuracy + ROWWISE = "rowwise" + + # lw's recipe for a modification on rowwise scaling: + # + # output_hp = input_fp8_rowwise_dim0 @ weight_t_rowwise_dim1 + # grad_input_hp = grad_output_fp8_rowwise_dim0 @ weight_fp8_tensorwise + # grad_weight_hp = input_t_hp @ grad_output_hp + # + # key characteristics: + # * increased accuracy for grad_weight + # * `input`, `weight` and `grad_output` now only need to be scaled + # rowwise across a single dim compared to vanilla rowwise, + # which is more amenable to fast kernels + # * the e4m3 dtype is used across the board, including for gradients + ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp" + + @dataclass(frozen=True) class Float8LinearConfig: """ @@ -321,86 +347,69 @@ def __post_init__(self): "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details." ) + @staticmethod + def from_recipe_name( + recipe_name: Union[Float8LinearRecipeName, str], + ) -> "Float8LinearConfig": + """ + Input: `Float8LinearRecipeName` value, or a string representing a `Float8LinearRecipeName` value + Output: a `Float8LinearConfig` configured to implement the specified recipe + """ + if type(recipe_name) == str: + valid_names = [n.value for n in Float8LinearRecipeName] + assert ( + recipe_name in valid_names + ), f"recipe_name {recipe_name} not in valid names {valid_names}" + recipe_name = Float8LinearRecipeName(recipe_name) -# Pre-made recipes for common configurations -# TODO(future PR): go through a round of design on this, and eventually expose -# as a top level public API. -class Float8LinearRecipeName(enum.Enum): - TENSORWISE = "tensorwise" - ROWWISE = "rowwise" - ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp" + if recipe_name is Float8LinearRecipeName.TENSORWISE: + return Float8LinearConfig() + + elif recipe_name is Float8LinearRecipeName.ROWWISE: + cc_i = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_w = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + # enable power of 2 scaling factors by default for row-wise scaling + round_scales_to_power_of_2=True, + ) -def recipe_name_to_linear_config( - recipe_name: Float8LinearRecipeName, -) -> Float8LinearConfig: - """ - Input: `Float8LinearRecipeName` value - Output: a `Float8LinearConfig` configured to implement the recipe - """ + elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP: - if recipe_name is Float8LinearRecipeName.TENSORWISE: - # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel - return Float8LinearConfig() - - elif recipe_name is Float8LinearRecipeName.ROWWISE: - # dynamic axiswise scaling with the CUTLASS rowwise kernel - cc_i = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype - ) - cc_w = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype - ) - cc_go = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype - ) - - return Float8LinearConfig( - cast_config_input=cc_i, - cast_config_weight=cc_w, - cast_config_grad_output=cc_go, - # enable power of 2 scaling factors by default for row-wise scaling - round_scales_to_power_of_2=True, - ) - - elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP: - # lw's recipe for a modification on all-axiswise: - # - # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 - # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise - # grad_weight_hp = input_t_hp @ grad_output_hp - # - # key characteristics: - # * increased accuracy for grad_weight - # * `input`, `weight` and `grad_output` now only need to be scaled - # axiswise across a single dim compared to vanilla all-axiswise, - # which is more amenable to fast kernels - # * the e4m3 dtype is used across the board, including for gradients - - # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 - cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) - - # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise - cc_go = CastConfig( - scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype - ) - cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) - - # grad_weight_hp = input_t_hp @ grad_output_hp - cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) - cc_go_gw = CastConfig( - scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype - ) - - return Float8LinearConfig( - cast_config_input=cc_i, - cast_config_weight=cc_w, - cast_config_grad_output=cc_go, - cast_config_input_for_grad_weight=cc_i_gw, - cast_config_weight_for_grad_input=cc_w_gi, - cast_config_grad_output_for_grad_weight=cc_go_gw, - ) - - else: - raise AssertionError(f"unknown recipe_name {recipe_name}") + # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 + cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) + + # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise + cc_go = CastConfig( + scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype + ) + cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE) + + # grad_weight_hp = input_t_hp @ grad_output_hp + cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED) + cc_go_gw = CastConfig( + scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype + ) + + return Float8LinearConfig( + cast_config_input=cc_i, + cast_config_weight=cc_w, + cast_config_grad_output=cc_go, + cast_config_input_for_grad_weight=cc_i_gw, + cast_config_weight_for_grad_input=cc_w_gi, + cast_config_grad_output_for_grad_weight=cc_go_gw, + ) + + else: + raise AssertionError(f"unknown recipe_name {recipe_name}") From ed16fe771a51d05e38a31c6fd2658aa4c7f35ca2 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 18 Feb 2025 19:59:53 -0800 Subject: [PATCH 094/115] float8 training: add README.md entry for rowwise scaling (#1733) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- torchao/float8/README.md | 55 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/torchao/float8/README.md b/torchao/float8/README.md index ddc717f953..4dbc556d83 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -17,9 +17,9 @@ throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs. We provide three per-tensor scaling strategies: dynamic, delayed and static. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`). -## float8 linear with dynamic scaling for `input`, `weight` and `grad_output` +## float8 linear with dynamic tensorwise scaling -This is the most accurate recipe as every tensor is scaled dynamically. +This is the default recipe, with a good balance of performance and accuracy. ```python import torch @@ -63,6 +63,57 @@ for _ in range(10): optimizer.step() ``` +## float8 linear with rowwise scaling + +This is a more accurate recipe compared to tensorwise, with more granular scaling. + +:warning: The composability of float8 with rowwise scaling with Tensor Parallelism is WIP, please see https://github.com/pytorch/ao/issues/1732 for more details. + +```python +import torch +import torch.nn as nn +from torchao.float8 import convert_to_float8_training, Float8LinearConfig +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + +if not TORCH_VERSION_AT_LEAST_2_5: + raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") + +# create model and sample input +m = nn.Sequential( + nn.Linear(2048, 4096), + nn.Linear(4096, 128), +).bfloat16().cuda() +x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16) +optimizer = torch.optim.SGD(m.parameters(), lr=0.1) + +# optional: filter modules from being eligible for float8 conversion +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # don't convert the last module + if fqn == "1": + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + +# configure rowwise scaling +config = Float8LinearConfig.from_recipe_name("rowwise") + +# convert specified `torch.nn.Linear` modules to `Float8Linear` +convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn) + +# enable torch.compile for competitive performance +m = torch.compile(m) + +# toy training loop +for _ in range(10): + optimizer.zero_grad() + y = m(x) + y.sum().backward() + optimizer.step() +``` + ## float8 linear with delayed scaling :warning: We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details. From ceceea505d37a91a4489bca683f914c1d37ef084 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 18 Feb 2025 22:04:19 -0800 Subject: [PATCH 095/115] promote blocksparse from prototype, make it faster (#1734) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR promotes block sparsity from prototype in torchao. Chiefly, it ports over the triton addmm blocksparse kernels from core, and makes several performance improvements to them. All of the numbers reported below are for an H100, with blocksize=64 and sparsity_level=0.9. The default dense baseline is 134 tok/s 1) Adds padding support to the triton kernel for dense matrices with dimension < 16, like those we run into during decoding. (214 -> 218 tok/s) 2) Changes the default [num_stages](https://github.com/triton-lang/triton/discussions/512) parameter from 1 to 4. This has a large effect on performance, and it seemed like the default kernel autotuning either does not modify or deems this parameter to be unimportant for some reason. (218 -> 263 tok/s). 3) Adds an env_var, BSR_AUTOTUNE, that users can use if they want to do kernel autotuning on top of the default parameters. (263 -> 266 tok/s) This seems to matter more for bs=n compute bound workloads, where I see a reduction from 0.3855 to 0.3745s on bs=8192 prefill (roughly 3%) So in total we are seeing a **1.985x** speedup 🚀 I've also updated the documentation to not reference prototype - planning on updating the diagram in a subsequent PR. ### Testing I added a new test case for the padding inputs and moved the test file out of prototype. ``` python test/sparsity/test_sparse_api.py ``` --- .../test_sparse_api.py | 9 +- torchao/_models/llama/generate.py | 39 +- torchao/kernel/__init__.py | 2 + torchao/kernel/bsr_triton_ops.py | 667 ++++++++++++++++++ torchao/ops.py | 10 + torchao/sparsity/README.md | 4 +- torchao/sparsity/__init__.py | 2 + .../superblock => sparsity}/blocksparse.py | 141 +++- torchao/sparsity/sparse_api.py | 12 +- 9 files changed, 843 insertions(+), 43 deletions(-) rename test/{prototype => sparsity}/test_sparse_api.py (96%) create mode 100644 torchao/kernel/bsr_triton_ops.py rename torchao/{prototype/sparsity/superblock => sparsity}/blocksparse.py (63%) diff --git a/test/prototype/test_sparse_api.py b/test/sparsity/test_sparse_api.py similarity index 96% rename from test/prototype/test_sparse_api.py rename to test/sparsity/test_sparse_api.py index 31fb85ffde..558474714c 100644 --- a/test/prototype/test_sparse_api.py +++ b/test/sparsity/test_sparse_api.py @@ -132,8 +132,9 @@ class TestBlockSparseWeight(common_utils.TestCase): ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.parametrize("compile", [True, False]) - def test_sparse(self, compile): - input = torch.rand((1024, 1024)).half().cuda() + @common_utils.parametrize("input_shape", [1, 1024]) + def test_sparse(self, compile, input_shape): + input = torch.rand((input_shape, 1024)).half().cuda() model = ( nn.Sequential( nn.Linear(1024, 2048), @@ -152,9 +153,7 @@ def test_sparse(self, compile): model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16) dense_result = model(input) - from torchao.prototype.sparsity.superblock.blocksparse import ( - block_sparse_weight, - ) + from torchao.sparsity import block_sparse_weight sparsify_(model, block_sparse_weight(blocksize=64)) # if compile: diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 69b0fb6e99..0958a5207c 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -793,9 +793,37 @@ def ffn_or_attn_only(mod, fqn): from torchao.sparsity import semi_sparse_weight, sparsify_ if "semi" in sparsity: - # TODO there is a bug here, need to fix + # Fixed sparsity level for 2:4 sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only) + if "bsr" in sparsity: + from torchao.sparsity import SupermaskLinear, block_sparse_weight + + # parse "bsr-0.9-64" + _, sparsity_level, blocksize = sparsity.split("-") + sparsity_level, blocksize = float(sparsity_level), int(blocksize) + sparsify_( + model, + lambda x: SupermaskLinear.from_linear( + x, + sparsity_level=sparsity_level, + blocksize=blocksize, + ), + filter_fn=ffn_only, + ) + print(model) + sparsify_( + model, + SupermaskLinear.to_linear, + filter_fn=ffn_only, + ) + print(model) + + # Accelerate with triton bsr kernels + sparsify_( + model, block_sparse_weight(blocksize=blocksize), filter_fn=ffn_only + ) + model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 if save: @@ -810,7 +838,10 @@ def ffn_or_attn_only(mod, fqn): print("Compiling Model") global decode_one_token, prefill decode_one_token = torch.compile( - decode_one_token, mode="reduce-overhead", fullgraph=True + decode_one_token, + mode="reduce-overhead", + fullgraph=True, + dynamic=True, ) if compile_prefill: @@ -849,7 +880,7 @@ def ffn_or_attn_only(mod, fqn): prompt = f"{B_INST} {prompt.strip()} {E_INST}" encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) - if interactive and i >= 0: + if interactive and i >= 0 and prefill_size is None: buffer = [] period_id = tokenizer.encode(".")[0] done_generating = False @@ -919,7 +950,7 @@ def callback(x): device_sync(device=device) # MKG t = time.perf_counter() - t0 - if not interactive and demo_summarize_prompt is None: + if not interactive and demo_summarize_prompt is None and prefill_size is None: tok_list = y[0].tolist() # truncate text after end of string token tokens = ( diff --git a/torchao/kernel/__init__.py b/torchao/kernel/__init__.py index 409da72601..ed5c64e31d 100644 --- a/torchao/kernel/__init__.py +++ b/torchao/kernel/__init__.py @@ -1,6 +1,8 @@ +from torchao.kernel.bsr_triton_ops import bsr_dense_addmm from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm __all__ = [ + "bsr_dense_addmm", "safe_int_mm", "int_scaled_matmul", ] diff --git a/torchao/kernel/bsr_triton_ops.py b/torchao/kernel/bsr_triton_ops.py new file mode 100644 index 0000000000..2dcdead966 --- /dev/null +++ b/torchao/kernel/bsr_triton_ops.py @@ -0,0 +1,667 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import os +from typing import Optional + +import torch + +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 + +if TORCH_VERSION_AT_LEAST_2_4: + from torch._dynamo.utils import warn_once +else: + import warnings + + warn_once = warnings.warn +from torch.sparse._triton_ops import ( + broadcast_batch_dims, + launch_kernel, + prepare_inputs, + ptr_stride_extractor, + tile_to_blocksize, +) +from torch.sparse._triton_ops_meta import get_meta, minimize, update +from torch.utils._triton import has_triton + +AUTOTUNE = os.getenv("BSR_AUTOTUNE", False) + + +def tune_bsr_dense_addmm( + input, + bsr, + dense, + *, + beta=1, + alpha=1, + left_alpha=None, + right_alpha=None, + out=None, + store=False, + verbose=False, + force=False, + opname=None, +): + """Tune bsr_dense_addmm kernel parameters against the given inputs. + + When store is True, the tuning results will be stored in the + database of kernel parameters. + """ + import triton + + if opname is None: + opname = "bsr_dense_addmm" + + N = dense.shape[-1] + values = bsr.values() + crow_indices = bsr.crow_indices() + batch_ndim = crow_indices.dim() - 1 + M, K = bsr.shape[batch_ndim : batch_ndim + 2] + BM, BK = values.shape[batch_ndim + 1 : batch_ndim + 3] + + # Reference parameters is a set of parameters that leads to a + # successful kernel call and the corresponding timing is used as a + # reference for computing speedups. Avoid changing the reference + # parameters when possible. + reference_meta = dict( + GROUP_SIZE_ROW=1, num_stages=4, num_warps=4, SPLIT_N=max(N // BM, 1) + ) + + # Compute the key of parameters: + sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2) + dtype = bsr.dtype + if out is None: + out_dtype = dtype + else: + out_dtype = out.dtype + if out_dtype is dtype: + version_dtype = dtype + else: + version_dtype = (dtype, out_dtype) + version = (0, version_dtype, sparsity) + key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1) + + # For tuning, for an initial state, use parameters from the + # database if available, otherwise, use the reference parameters. + initial_meta = get_meta(opname, key, version=version, exact=True) + if initial_meta is None: + may_skip_update = False + initial_meta = get_meta(opname, key, version=(0, dtype, 0.5), exact=True) + if initial_meta is None: + initial_meta = reference_meta + elif not force: + return initial_meta + else: + may_skip_update = True + + # The target function that is minimized in the tuning process: + def bench(meta, input=input, bsr=bsr, dense=dense, alpha=alpha, out=out): + def test_func(): + return bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, + meta=meta, + out=out, + ) + + return triton.testing.do_bench(test_func, warmup=500, rep=100) + + # The step function that increments a specified meta parameter: + def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=BK): + # return next value in positive or negative direction, or + # input value if the step will result an invalid + # value. The input value is assumed to be valid. + is_log = name in {"SPLIT_N", "num_warps"} + min_value = dict(SPLIT_N=1, num_warps=1, num_stages=1, GROUP_SIZE_ROW=1)[name] + max_value = dict(SPLIT_N=max(N // BM, 1)).get(name) + value_step = dict(SPLIT_N=2, num_warps=2, num_stages=1, GROUP_SIZE_ROW=1)[name] + if is_log: + next_value = ( + value * value_step**direction + if direction > 0 + else value // (value_step ** abs(direction)) + ) + else: + next_value = value + value_step * direction + if min_value is not None: + next_value = max(next_value, min_value) + if max_value is not None: + next_value = min(next_value, max_value) + if name == "SPLIT_N" and N % next_value != 0: + return value + return next_value + + # Tune: + meta, speedup, timing, sensitivity_message = minimize( + bench, + initial_meta, + reference_meta, + step_meta_parameter, + max_step=2, + verbose=verbose, + ) + if verbose: + print(f"-> {sensitivity_message}, {speedup=:.1f} %, {timing=:.3f} ms") + + if store and not ( + may_skip_update and meta == initial_meta and initial_meta is not reference_meta + ): + device_name = torch.cuda.get_device_name() + update( + opname, + device_name, + version, + key, + tuple(meta[k] for k in sorted(meta)), + ) + + return meta + + +def bsr_dense_addmm_meta( + M, + K, + N, + Ms, + Ks, + beta, + alpha, + SPLIT_N=None, + GROUP_SIZE_ROW=None, + num_warps=None, + num_stages=None, + sparsity=None, + dtype=None, + out_dtype=None, + _version=0, + **extra, +): + # Specifying _version is useful for situations when one wants to + # discard existing triton kernel tuning results, say, in testing + # bsr_dense_addmm_meta functionality. + if dtype is None: + dtype = torch.float16 + if out_dtype is None: + out_dtype = dtype + if sparsity is None: + sparsity = 0.5 + if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}: + device_name = torch.cuda.get_device_name() + key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1) + if dtype is out_dtype: + version_dtype = dtype + else: + version_dtype = dtype, out_dtype + meta = get_meta( + "bsr_dense_addmm", + key, + device_name, + version=(_version, version_dtype, sparsity), + ) + if meta is None and sparsity != 0.5: + meta = get_meta( + "bsr_dense_addmm", + key, + device_name, + version=(_version, version_dtype, 0.5), + ) + if meta is None and dtype is not out_dtype: + meta = get_meta( + "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5) + ) + if meta is None: + # find approximate meta such that N % SPLIT_N == 0. + matching_meta = get_meta( + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(_version, version_dtype, 0.5), + ) + if matching_meta is None and dtype is not out_dtype: + matching_meta = get_meta( + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(_version, dtype, 0.5), + ) + for mkey in sorted(matching_meta or {}): + meta_ = matching_meta[mkey] + n = mkey[2] + split_n = meta_["SPLIT_N"] + c = n // split_n + if N % c == 0 and n <= N: + meta = dict(meta_) + meta["SPLIT_N"] = N // c + if meta is not None: + meta.update(**extra) + return meta + else: + warn_once( + "bsr_dense_addmm uses non-optimal triton kernel parameters" + f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}. " + "To find optimal triton kernel parameters, run with BSR_AUTOTUNE=1" + ) + + SPLIT_N = SPLIT_N or max(N // Ms, 1) + GROUP_SIZE_ROW = GROUP_SIZE_ROW or 4 + num_stages = num_stages or 4 + num_warps = num_warps or 4 + return dict( + SPLIT_N=SPLIT_N, + GROUP_SIZE_ROW=GROUP_SIZE_ROW, + num_stages=num_stages, + num_warps=num_warps, + **extra, + ) + + +def bsr_dense_addmm( + input: torch.Tensor, + bsr: torch.Tensor, + dense: torch.Tensor, + *, + beta=1, + alpha=1, + left_alpha: Optional[torch.Tensor] = None, + right_alpha: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + skip_checks: bool = False, + max_grid: Optional[tuple[Optional[int], Optional[int], Optional[int]]] = None, + meta: Optional[dict] = None, +): + """Compute + + out = beta * input + left_alpha.reshape(-1, 1) * (alpha * (bsr @ dense)) * right_alpha.reshape(1, -1) + + where left_alpha, right_alpha are (* + 1)-D tensors when + specified, otherwise, these are treated as tensors filled with + ones. + """ + f_name = "bsr_dense_addmm" + values = bsr.values() + crow_indices = bsr.crow_indices() + col_indices = bsr.col_indices() + batch_ndim = crow_indices.dim() - 1 + M, K = bsr.shape[batch_ndim : batch_ndim + 2] + blocksize = values.shape[batch_ndim + 1 : batch_ndim + 3] + N = dense.shape[-1] + + original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) + if out is None: + out = dense.new_empty(original_batch_dims_broadcasted + (M, N)) + + if bsr._nnz() == 0 or alpha == 0 or N == 0 or M == 0 or K == 0: + if beta == 0: + out.zero_() + else: + out.copy_(input) + if beta != 1: + out.mul_(beta) + return out + + if meta is None: + sparsity = round(1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K), 2) + if AUTOTUNE: + meta = tune_bsr_dense_addmm( + input, + bsr, + dense, + beta=beta, + alpha=alpha, + left_alpha=left_alpha, + right_alpha=right_alpha, + out=out, + store=True, + force=False, + verbose=True, + opname="bsr_dense_addmm", + ) + else: + meta = bsr_dense_addmm_meta( + M, + K, + N, + blocksize[0], + blocksize[1], + beta, + alpha, + sparsity=sparsity, + dtype=dense.dtype, + out_dtype=out.dtype, + ) + + left_alpha_is_one = False + right_alpha_is_one = False + if left_alpha is None: + left_alpha_is_one = True + left_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + left_alpha = left_alpha.view(*original_batch_dims_broadcasted, M, 1).expand( + *original_batch_dims_broadcasted, M, N + ) + + if right_alpha is None: + right_alpha_is_one = True + right_alpha = dense.new_empty(()).expand( + *original_batch_dims_broadcasted, M, N + ) # not referenced + else: + right_alpha = right_alpha.view(*original_batch_dims_broadcasted, 1, N).expand( + *original_batch_dims_broadcasted, M, N + ) + assert left_alpha.stride()[-1] == 0 + assert right_alpha.stride()[-2] == 0 + + out_backup = out + + ( + crow_indices, + col_indices, + values, + input, + dense, + left_alpha, + right_alpha, + out, + ) = prepare_inputs(bsr, input, dense, left_alpha, right_alpha, out) + + BM, BK = blocksize + SPLIT_N = meta.get("SPLIT_N", max(N // BM, 1)) + BN = N // SPLIT_N + + out_untiled = out + out = tile_to_blocksize(out, (BM, BN)) + dense = tile_to_blocksize(dense, (BK, BN)) + input = tile_to_blocksize(input, (BM, BN)) + left_alpha = tile_to_blocksize(left_alpha, (BM, BN)) + right_alpha = tile_to_blocksize(right_alpha, (BM, BN)) + + # tl.dot supports float16, float32, int32 as accumulator types. + dot_out_dtype = { + torch.float16: tl.float32, + torch.bfloat16: tl.float32, + torch.float32: tl.float64, + torch.float64: tl.float64, + torch.int8: tl.int32, + torch.int32: tl.int32, + }[out.dtype] + + n_batches = dense.size(0) + n_block_rows = crow_indices.size(-1) - 1 + n_block_cols = dense.size(-3) + + full_grid = (n_batches, n_block_cols, n_block_rows) + if max_grid is not None: + grid_blocks = tuple(max_grid[:3][::-1]) + (None,) * (3 - len(max_grid[:3])) + else: + grid_blocks = None + + tensor_dims_map = { + values: (0, None, None), + crow_indices: (0, None, -1), + col_indices: (0, None, None), + input: (0, -3, -4), + dense: (0, -3, None), + left_alpha: (0, -3, -4), + right_alpha: (0, -3, -4), + out: (0, -3, -4), + } + + assert alpha != 0 + + def kernel(grid, *sliced_tensors): + _bsr_strided_addmm_kernel[grid]( + *ptr_stride_extractor(*sliced_tensors), + beta, + alpha, + beta_is_one=beta == 1, + beta_is_nonzero=beta != 0, + alpha_is_one=alpha == 1, + left_alpha_is_one=left_alpha_is_one, + right_alpha_is_one=right_alpha_is_one, + BLOCKSIZE_ROW=BM, + BLOCKSIZE_INNER=BK, + BLOCKSIZE_COL=BN, + allow_tf32=dot_out_dtype == tl.float32, + acc_dtype=dot_out_dtype, + **meta, + ) + + launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks) + + if out.data_ptr() != out_backup.data_ptr(): + # prepare_inputs has made a copy of out, copy its content back + # to out_backup: + out_backup.copy_(out_untiled.view(out_backup.shape)) + + return out_backup + + +if has_triton(): + import triton + import triton.language as tl + + @triton.jit + def _bsr_strided_addmm_kernel( + # values prologue + values_ptr, + values_batch_stride, + values_nnz_stride, + values_row_block_stride, + values_col_block_stride, + # values epilogue + # crow_indices prologue + crow_indices_ptr, + crow_indices_batch_stride, + crow_indices_stride, + # crow_indices epilogue + # col_indices prologue + col_indices_ptr, + col_indices_batch_stride, + col_indices_stride, + # col_indices epilogue + # input prologue + input_ptr, + input_batch_stride, + input_tiled_row_stride, + input_tiled_col_stride, + input_row_block_stride, + input_col_block_stride, + # input epilogue + # dense prologue + dense_ptr, + dense_batch_stride, + dense_tiled_row_stride, + dense_tiled_col_stride, + dense_row_block_stride, + dense_col_block_stride, + # dense epilogue + # left_alpha prologue + left_alpha_ptr, + left_alpha_batch_stride, + left_alpha_tiled_row_stride, + left_alpha_tiled_col_stride: tl.constexpr, + left_alpha_row_block_stride, + left_alpha_col_block_stride: tl.constexpr, + # left_alpha epilogue + # right_alpha prologue + right_alpha_ptr, + right_alpha_batch_stride, + right_alpha_tiled_row_stride: tl.constexpr, + right_alpha_tiled_col_stride, + right_alpha_row_block_stride: tl.constexpr, + right_alpha_col_block_stride, + # right_alpha epilogue + # output prologue + output_ptr, + output_batch_stride, + output_tiled_row_stride, + output_tiled_col_stride, + output_row_block_stride, + output_col_block_stride, + # output epilogue + beta, + alpha, + beta_is_one: tl.constexpr, + beta_is_nonzero: tl.constexpr, + alpha_is_one: tl.constexpr, + left_alpha_is_one: tl.constexpr, + right_alpha_is_one: tl.constexpr, + BLOCKSIZE_ROW: tl.constexpr, + BLOCKSIZE_COL: tl.constexpr, + BLOCKSIZE_INNER: tl.constexpr, + acc_dtype: tl.constexpr, + allow_tf32: tl.constexpr, + GROUP_SIZE_ROW: tl.constexpr, + SPLIT_N: tl.constexpr, + ): + # left/right_alpha tensors are originally (* + 1)-dimensional + assert left_alpha_tiled_col_stride == 0 + assert left_alpha_col_block_stride == 0 + assert right_alpha_tiled_row_stride == 0 + assert right_alpha_row_block_stride == 0 + + batch_pid = tl.program_id(axis=2) + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + n_block_rows = tl.num_programs(axis=0) + n_block_cols = tl.num_programs(axis=1) + + row_block_pid, col_block_pid = tl.swizzle2d( + row_block_pid, col_block_pid, n_block_rows, n_block_cols, GROUP_SIZE_ROW + ) + + crow_indices_offset_ptr = ( + crow_indices_ptr + + crow_indices_batch_stride * batch_pid + + crow_indices_stride * row_block_pid + ) + nnz_offset = tl.load(crow_indices_offset_ptr) + nnz_offset_next = tl.load(crow_indices_offset_ptr + crow_indices_stride) + + # Compute nnz for the row with number row_block_pid. + row_nnz = nnz_offset_next - nnz_offset + + row_block_arange = tl.arange(0, BLOCKSIZE_ROW) + inner_block_arange = tl.arange(0, BLOCKSIZE_INNER) + + if BLOCKSIZE_COL < 16 or BLOCKSIZE_COL % 16 != 0: + PADDED_BLOCKSIZE_COL: tl.constexpr = 16 + else: + PADDED_BLOCKSIZE_COL: tl.constexpr = BLOCKSIZE_COL + + col_block_arange = tl.arange(0, PADDED_BLOCKSIZE_COL) + + # Pointers are set to the first block of the current row. + values_block_ptrs = ( + values_ptr + + values_batch_stride * batch_pid + + values_nnz_stride * nnz_offset + + values_row_block_stride * row_block_arange[:, None] + + values_col_block_stride * inner_block_arange[None, :] + ) + + # NOTE: dense is advanced into all dimensions but the tiled row one. + # That will be advanced in the loop according to values in col_indices. + dense_block_ptrs = ( + dense_ptr + + dense_batch_stride * batch_pid + + dense_tiled_col_stride * col_block_pid + + dense_row_block_stride * inner_block_arange[:, None] + + dense_col_block_stride * col_block_arange[None, :] + ) + + # Pointers are set to exact write-to locations + output_ptrs = ( + output_ptr + + output_batch_stride * batch_pid + + output_tiled_row_stride * row_block_pid + + output_tiled_col_stride * col_block_pid + + output_row_block_stride * row_block_arange[:, None] + + output_col_block_stride * col_block_arange[None, :] + ) + + # Set pointer to the first nonzero element in the current row + col_index_nnz_ptr = ( + col_indices_ptr + + col_indices_batch_stride * batch_pid + + col_indices_stride * nnz_offset + ) + + output_acc_block = tl.zeros( + (BLOCKSIZE_ROW, PADDED_BLOCKSIZE_COL), dtype=acc_dtype + ) + for _ in range(row_nnz): + values_block = tl.load(values_block_ptrs) + + # find which row of dense needs to get loaded + # for multiplication with values_block. + dense_row_idx = tl.load(col_index_nnz_ptr) + dense_block = tl.load( + dense_block_ptrs + dense_tiled_row_stride * dense_row_idx, + mask=col_block_arange[None, :] < BLOCKSIZE_COL, + ) + + # do block mm + output_acc_block += tl.dot( + values_block, dense_block, allow_tf32=allow_tf32, out_dtype=acc_dtype + ) + + # move val/col_index ptrs to the next block in the row + values_block_ptrs += values_nnz_stride + col_index_nnz_ptr += col_indices_stride + + if not alpha_is_one: + output_acc_block *= alpha + + if not left_alpha_is_one: + left_alpha_ptrs = ( + left_alpha_ptr + + left_alpha_batch_stride * batch_pid + + left_alpha_tiled_row_stride * row_block_pid + + left_alpha_tiled_col_stride * col_block_pid + + left_alpha_row_block_stride * row_block_arange[:, None] + + left_alpha_col_block_stride * col_block_arange[None, :] + ) + output_acc_block *= tl.load(left_alpha_ptrs) + + if not right_alpha_is_one: + right_alpha_ptrs = ( + right_alpha_ptr + + right_alpha_batch_stride * batch_pid + + right_alpha_tiled_row_stride * row_block_pid + + right_alpha_tiled_col_stride * col_block_pid + + right_alpha_row_block_stride * row_block_arange[:, None] + + right_alpha_col_block_stride * col_block_arange[None, :] + ) + output_acc_block *= tl.load(right_alpha_ptrs) + + if beta_is_nonzero: + input_ptrs = ( + input_ptr + + input_batch_stride * batch_pid + + input_tiled_row_stride * row_block_pid + + input_tiled_col_stride * col_block_pid + + input_row_block_stride * row_block_arange[:, None] + + input_col_block_stride * col_block_arange[None, :] + ) + if beta_is_one: + output_acc_block += tl.load(input_ptrs) + else: + output_acc_block += beta * tl.load(input_ptrs) + + # write back the result + tl.store( + output_ptrs, + output_acc_block.to(output_ptr.dtype.element_ty), + mask=col_block_arange[None, :] < BLOCKSIZE_COL, + ) + +else: + _bsr_strided_addmm_kernel = None # type: ignore[assignment] diff --git a/torchao/ops.py b/torchao/ops.py index 56980b17f1..bba2a054fc 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -39,6 +39,16 @@ def decorator(func): return decorator +def register_custom_op_impl(name): + def decorator(func): + if TORCH_VERSION_AT_LEAST_2_4: + return torch.library.custom_op(f"{name}", mutates_args=())(func) + else: + return torch.library.impl(f"{name}", "CUDA")(func) + + return decorator + + def quant_llm_linear( EXPONENT: int, MANTISSA: int, diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md index be7fa8979b..b689a3adf4 100644 --- a/torchao/sparsity/README.md +++ b/torchao/sparsity/README.md @@ -85,12 +85,12 @@ model = model.cuda() sparsify_(model, semi_sparse_weight()) ``` -### Block sparsity (prototype) +### Block sparsity We offer prototype support for accelerating block sparsity with our triton kernels for bfloat16/float16 workloads. ```py from torchao.sparsity.sparse_api import sparsify_ -from torchao.prototype.sparsity.superblock.blocksparse import block_sparse_weight +from torchao.sparsity import block_sparse_weight model = model.cuda() sparsify_(model, block_sparse_weight()) diff --git a/torchao/sparsity/__init__.py b/torchao/sparsity/__init__.py index c13bb4209c..e7f98332be 100644 --- a/torchao/sparsity/__init__.py +++ b/torchao/sparsity/__init__.py @@ -10,6 +10,7 @@ from .sparse_api import ( apply_fake_sparsity, + block_sparse_weight, semi_sparse_weight, sparsify_, ) @@ -24,5 +25,6 @@ "apply_fake_sparsity", "sparsify_", "semi_sparse_weight", + "block_sparse_weight", "int8_dynamic_activation_int8_semi_sparse_weight", ] diff --git a/torchao/prototype/sparsity/superblock/blocksparse.py b/torchao/sparsity/blocksparse.py similarity index 63% rename from torchao/prototype/sparsity/superblock/blocksparse.py rename to torchao/sparsity/blocksparse.py index b5e8432949..f0da181339 100644 --- a/torchao/prototype/sparsity/superblock/blocksparse.py +++ b/torchao/sparsity/blocksparse.py @@ -1,18 +1,17 @@ -from functools import partial from typing import List, Optional, Tuple import torch -from torch.sparse._triton_ops import broadcast_batch_dims, bsr_dense_addmm from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.quantization.quant_api import _get_linear_subclass_inserter +from torchao.kernel.bsr_triton_ops import broadcast_batch_dims, bsr_dense_addmm +from torchao.ops import register_custom_op, register_custom_op_impl from torchao.utils import TorchAOBaseTensor aten = torch.ops.aten # quantization support -@torch.library.custom_op("blocksparse::bsr_to_dense", mutates_args=()) +@register_custom_op_impl("blocksparse::bsr_to_dense") def bsr_to_dense( crow_indices: torch.Tensor, col_indices: torch.Tensor, @@ -25,7 +24,7 @@ def bsr_to_dense( ).to_dense() -@torch.library.register_fake("blocksparse::bsr_to_dense") +@register_custom_op("blocksparse::bsr_to_dense") def bsr_to_dense_abstract( crow_indices: torch.Tensor, col_indices: torch.Tensor, @@ -36,7 +35,7 @@ def bsr_to_dense_abstract( return torch.empty((M, K), dtype=values.dtype, device=values.device) -@torch.library.custom_op("blocksparse::int_addmm", mutates_args=()) +@register_custom_op_impl("blocksparse::int_addmm") def blocksparse_int_addmm( crow_indices: torch.Tensor, col_indices: torch.Tensor, @@ -66,7 +65,7 @@ def blocksparse_int_addmm( ).t() -@torch.library.register_fake("blocksparse::int_addmm") +@register_custom_op("blocksparse::int_addmm") def blocksparse_int_addmm_abstract( crow_indices: torch.Tensor, col_indices: torch.Tensor, @@ -81,10 +80,9 @@ def blocksparse_int_addmm_abstract( return torch.empty((M, N), dtype=torch.bfloat16, device=A.device).t() -# bsr wrapper custom op -@torch.library.custom_op("blocksparse::linear", mutates_args=()) -def blocksparse_linear( - A: torch.Tensor, +@register_custom_op_impl("blocksparse::addmm") +def blocksparse_addmm( + x_padded: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, @@ -92,13 +90,24 @@ def blocksparse_linear( K: int, bias: torch.Tensor, ) -> torch.Tensor: - weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) - return torch.nn.functional.linear(A, weight_bsr, bias) + assert bias is None + bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) + N_padded = x_padded.shape[1] + out = x_padded.new_empty((M, N_padded)) + bsr_dense_addmm( + out, + bsr, + x_padded, + alpha=1, + beta=0, + out=out, + ) + return out -@torch.library.register_fake("blocksparse::linear") -def blocksparse_linear_abstract( - A: torch.Tensor, +@register_custom_op("blocksparse::addmm") +def blocksparse_addmm_abstract( + x_padded: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, @@ -106,8 +115,8 @@ def blocksparse_linear_abstract( K: int, bias: torch.Tensor, ) -> torch.Tensor: - new_shape = A.shape[:-1] + (M,) - return torch.empty(new_shape, dtype=A.dtype, device=A.device) + N_padded = x_padded.shape[1] + return x_padded.new_empty((M, N_padded)) # Subclass definition @@ -115,6 +124,7 @@ class BlockSparseTensor(TorchAOBaseTensor): bsr_crow_indices: Optional[torch.Tensor] bsr_col_indices: Optional[torch.Tensor] bsr_values: Optional[torch.Tensor] + blocksize: int __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values"] @@ -122,6 +132,7 @@ class BlockSparseTensor(TorchAOBaseTensor): def __new__( # noqa: PYI034 cls, shape: torch.Size, + blocksize: int, bsr_crow_indices: Optional[torch.Tensor], bsr_col_indices: Optional[torch.Tensor], bsr_values: Optional[torch.Tensor], @@ -141,33 +152,36 @@ def __new__( # noqa: PYI034 "requires_grad": requires_grad, } tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + tensor.blocksize = blocksize tensor.bsr_crow_indices = bsr_crow_indices - tensor.bsr_col_indices = bsr_col_indices tensor.bsr_values = bsr_values + tensor.bsr_col_indices = bsr_col_indices return tensor def __repr__(self) -> str: # type: ignore[override] assert hasattr(self, "shape") return f"{self.__class__.__name__}(shape={self.shape})" - def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool]]: + def __tensor_flatten__(self) -> Tuple[List[str], Tuple[torch.Size, bool, int]]: inner_tensors = list( filter(lambda x: getattr(self, x) is not None, self.__slots__) ) - tensor_meta = (self.shape, self.requires_grad) + tensor_meta = (self.shape, self.requires_grad, self.blocksize) return inner_tensors, tensor_meta @classmethod def __tensor_unflatten__( cls, inner_tensors, - tensor_meta: Tuple[torch.Size, bool], + tensor_meta: Tuple[torch.Size, bool, int], outer_size, outer_stride, ) -> torch.Tensor: - shape, requires_grad = tensor_meta + shape, requires_grad, blocksize = tensor_meta + # print("unflatten", outer_size, outer_stride) return cls( shape=shape, + blocksize=blocksize, bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), bsr_col_indices=inner_tensors.get("bsr_col_indices", None), bsr_values=inner_tensors.get("bsr_values", None), @@ -177,8 +191,10 @@ def __tensor_unflatten__( @classmethod def from_dense(cls, dense_tensor, blocksize): bsr_tensor = dense_tensor.to_sparse_bsr(blocksize) + # bsr_tensor_t = dense_tensor.t().contiguous().to_sparse_bsr(blocksize) return cls( shape=dense_tensor.shape, + blocksize=blocksize, bsr_crow_indices=bsr_tensor.crow_indices(), bsr_col_indices=bsr_tensor.col_indices(), bsr_values=bsr_tensor.values(), @@ -188,6 +204,7 @@ def from_dense(cls, dense_tensor, blocksize): def apply_fn_to_shard(self, func): return BlockSparseTensor( shape=self.shape, + blocksize=self.blocksize, bsr_crow_indices=func(self.bsr_crow_indices), bsr_col_indices=func(self.bsr_col_indices), bsr_values=func(self.bsr_values), @@ -206,6 +223,59 @@ def block_sparse_detach(func, types, args, kwargs): ) +@implements(aten.unsqueeze.default) +def block_sparse_unsqueeze(func, types, args, kwargs): + assert len(args) == 2 + assert len(kwargs) == 0 + assert args[-1] == 2 + bsr = args[0] + assert bsr.dim() == 2 + assert not bsr.requires_grad + return BlockSparseTensor( + bsr.shape + (1,), + bsr.blocksize, + bsr.crow_indices(), + bsr.col_indices(), + bsr.values().unsqueeze(-1), + requires_grad=False, + ) + + +@implements(aten.mul.Tensor) +def block_sparse_mul(func, types, args, kwargs): + assert len(args) == 2 + assert len(kwargs) == 0 + bsr, t = args + + def my_mul(bsr, t): + assert isinstance(bsr, BlockSparseTensor) + assert isinstance(t, torch.Tensor) + assert bsr.dim() == 3 + assert t.dim() == 3 + assert not bsr.requires_grad + assert t.size(0) == 1 + t_blocked = t.view(t.size(0), t.size(1) // bsr.blocksize, bsr.blocksize, 1) + masked_t = t_blocked.transpose(0, 1).index_select(0, bsr.col_indices()) + new_values = bsr.values() * masked_t + return BlockSparseTensor( + bsr.shape, bsr.blocksize, bsr.crow_indices(), bsr.col_indices(), new_values + ) + + if isinstance(bsr, torch.Tensor) and isinstance(t, BlockSparseTensor): + return my_mul(t, bsr) + return my_mul(bsr, t) + + +@implements(aten.sum.dim_IntList) +def block_sparse_sum(func, types, args, kwargs): + bsr, dim = args + assert type(dim) == list + assert len(dim) == 1 + dim = dim[0] + assert dim == 1 + return torch.ops.blocksparse.sum(bsr.values(), bsr.crow_indices(), bsr.shape[0]) + + @implements(aten.values.default) def block_sparse_values(func, types, args, kwargs): return args[0].bsr_values.detach() @@ -228,13 +298,22 @@ def block_sparse__nnz(func, types, args, kwargs): @implements(torch.nn.functional.linear) def block_sparse_linear(func, types, args, kwargs): - x, w, bias = args - return torch.ops.blocksparse.linear( - x, w.crow_indices(), w.col_indices(), w.values(), w.shape[0], w.shape[1], bias + x_orig, w, bias = args + x = x_orig.reshape(-1, x_orig.size(-1)).t() + M = w.shape[0] + K = w.shape[1] + + out = torch.ops.blocksparse.addmm( + x, + w.crow_indices(), + w.col_indices(), + w.values(), + M, + K, + None, ) + out_orig = out.t() + if bias is None: + return out_orig - -def block_sparse_weight(blocksize=64): - return _get_linear_subclass_inserter( - partial(BlockSparseTensor.from_dense, blocksize=blocksize) - ) + return out_orig + bias diff --git a/torchao/sparsity/sparse_api.py b/torchao/sparsity/sparse_api.py index eb31cba619..9e9611e0ad 100644 --- a/torchao/sparsity/sparse_api.py +++ b/torchao/sparsity/sparse_api.py @@ -1,14 +1,18 @@ +from functools import partial from typing import Callable, Optional import torch -from torch.ao.pruning import WeightNormSparsifier from torch.sparse import to_sparse_semi_structured +from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import ( + WeightNormSparsifier, +) from torchao.quantization.quant_api import ( _get_linear_subclass_inserter, _is_linear, _replace_with_custom_fn_if_matches_filter, ) +from torchao.sparsity.blocksparse import BlockSparseTensor # Sparsity helper functions @@ -31,6 +35,12 @@ def apply_fake_sparsity(model, **kwargs): sparsifier.squash_mask() +def block_sparse_weight(blocksize=64): + return _get_linear_subclass_inserter( + partial(BlockSparseTensor.from_dense, blocksize=blocksize) + ) + + def semi_sparse_weight(): """ Convert the weight of linear moduels to semi-structured (2:4) sparsity From 217d9688baf3f41de3225fafd0b717e3074e7482 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Wed, 19 Feb 2025 11:03:02 -0500 Subject: [PATCH 096/115] Make FakeQuantizer expose useful config details (#1717) **Summary:** Expose useful config details when printing FakeQuantizer, which appears when printing QAT prepared models containing linear layers. Before: ``` >>> print(prepared_model.layers[0].attn.qproj) FakeQuantizedLinear( in_features=4096, out_features=4096, bias=False (activation_fake_quantizer): FakeQuantizer() (weight_fake_quantizer): FakeQuantizer() ) ``` After: ``` >>> print(prepared_model.layers[0].attn.qproj) FakeQuantizedLinear( in_features=4096, out_features=4096, bias=False (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=, is_dynamic=True, range_learning=False)) (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=, is_dynamic=True, range_learning=False)) ) ``` **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantizer_repr --- test/quantization/test_qat.py | 18 ++++++++++++++++++ torchao/quantization/qat/fake_quantizer.py | 6 ++++++ 2 files changed, 24 insertions(+) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 82324394a8..9aeaa53664 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -31,6 +31,9 @@ from torchao.quantization.qat.embedding import ( FakeQuantizedEmbedding, ) +from torchao.quantization.qat.fake_quantizer import ( + FakeQuantizer, +) from torchao.quantization.qat.linear import ( FakeQuantizedLinear, Int4WeightOnlyQATLinear, @@ -1348,6 +1351,21 @@ def test_fake_quantize_config_torch_intx(self): out2 = linear2(*x2) torch.testing.assert_close(out1, out2, atol=0, rtol=0) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower" + ) + def test_fake_quantizer_repr(self): + """ + Test that `repr(FakeQuantizer(config))` exposes useful config details. + """ + config = FakeQuantizeConfig(torch.int4, group_size=128) + fake_quantizer = FakeQuantizer(config) + fake_quantizer_repr = repr(fake_quantizer) + self.assertTrue("dtype=torch.int4" in fake_quantizer_repr) + self.assertTrue("group_size=128" in fake_quantizer_repr) + self.assertTrue("PerGroup" in fake_quantizer_repr) + self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py index 15cd3aaca4..de747366a6 100644 --- a/torchao/quantization/qat/fake_quantizer.py +++ b/torchao/quantization/qat/fake_quantizer.py @@ -134,3 +134,9 @@ def _should_compute_qparams(self) -> bool: Return whether we need to compute new scales and zero points. """ return self.config.is_dynamic or self.scale is None or self.zero_point is None + + def __repr__(self) -> str: + """ + Return a human readable representation of this `FakeQuantizer` with config details. + """ + return "FakeQuantizer(%s)" % self.config From 4780e10d397e31cc13b6ca082e03eca34ef71024 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Wed, 19 Feb 2025 22:36:36 -0500 Subject: [PATCH 097/115] Update version.txt to 0.10.0 (#1714) Update version.txt --- version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.txt b/version.txt index ac39a106c4..78bc1abd14 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.9.0 +0.10.0 From f6f33220dae144f5ac682a52763f60856805cb25 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 20 Feb 2025 10:34:28 -0800 Subject: [PATCH 098/115] Add ukernel selection logic + clean up KleidiAI integration (#1652) * UKernel Selection, up, up, up, up * up --- .../workflows/torchao_experimental_test.yml | 17 +- setup.py | 3 +- ...i8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h | 122 ------ ...i8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h | 123 ------ ..._qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h | 120 ------ ..._qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h | 122 ------ .../kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h | 144 +++++-- .../cpu/aarch64/tests/build_and_run_tests.sh | 9 - .../kernels/cpu/aarch64/tests/test_linear.cpp | 332 -------------- .../embedding_xbit/packed_weights_header.h | 2 +- .../CMakeLists.txt | 12 + .../kernel_selector.h | 361 ++++++++++++++++ .../linear_8bit_act_xbit_weight.cpp | 70 +-- .../linear_8bit_act_xbit_weight.h | 45 +- .../op_linear_8bit_act_xbit_weight-impl.h | 94 ++-- .../packed_weights_header.h | 38 -- .../experimental/ops/packed_weights_header.h | 34 +- .../ops/tests/build_and_run_tests.sh | 3 + .../experimental/ops/tests/generate_tests.py | 10 + .../test_linear_8bit_act_xbit_weight.cpp | 406 +++++++++++++++--- 20 files changed, 982 insertions(+), 1085 deletions(-) delete mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h delete mode 100644 torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h create mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h delete mode 100644 torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 08f494c71d..e1511ffe9a 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -37,7 +37,22 @@ jobs: pip install numpy pip install pytest USE_CPP=1 pip install . - - name: Run tests + - name: Run python tests run: | conda activate venv pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py + python torchao/experimental/tests/test_embedding_xbit_quantizer.py + - name: Run kernels/cpu/aarch64/tests + run: | + conda activate venv + pushd torchao/experimental/kernels/cpu/aarch64/tests + sh build_and_run_tests.sh + rm -rf /tmp/cmake-out + popd + - name: Run torchao/experimental/ops/tests + run: | + conda activate venv + pushd torchao/experimental/ops/tests + sh build_and_run_tests.sh + rm -rf /tmp/cmake-out + popd diff --git a/setup.py b/setup.py index 6ee93bc9ab..357e0e491f 100644 --- a/setup.py +++ b/setup.py @@ -179,7 +179,8 @@ def build_cmake(self, ext): "cmake", ext.sourcedir, "-DCMAKE_BUILD_TYPE=" + build_type, - "-DTORCHAO_BUILD_EXECUTORCH_OPS=OFF", + # Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16 + "-DTORCHAO_BUILD_KLEIDIAI=OFF", "-DTorch_DIR=" + torch_dir, "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, ], diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h deleted file mode 100644 index 658a0feadc..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include - -#include - -namespace torchao::kernels::cpu::aarch64::kleidi { -namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { - -namespace neon_dotprod_1x4x32 { -const Ukernel get_ukernel() { - return Ukernel{ - .get_m_step = - kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_n_step = - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_mr = - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_nr = - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_kr = - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_sr = - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_dst_offset = - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .get_dst_size = - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - .run_matmul = - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod}; -} - -size_t activation_data_size(int m, int k, int group_size) { - (void)group_size; // unused - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( - get_ukernel(), m, k); -} - -void prepare_activation_data( - void* prepared_activation_data, - int m, - int k, - int group_size, - const float* activations) { - (void)group_size; // unused - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), prepared_activation_data, m, k, activations); -} - -size_t weight_data_size(int n, int k, int group_size) { - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( - get_ukernel(), n, k, group_size); -} - -void prepare_weight_data( - void* prepared_weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( - get_ukernel(), - prepared_weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -void kernel( - float32_t* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max) { - if (clamp_min == 0 && clamp_max == 0) { - clamp_min = std::numeric_limits::lowest(); - clamp_max = std::numeric_limits::max(); - } - - auto ukernel = get_ukernel(); - ukernel.run_matmul( - m, - n, - k, - group_size, - activation_data, - weight_data, - output, - /*dst_stride_row=*/output_m_stride * sizeof(float), - /*dst_stride_col=*/sizeof(float), - clamp_min, - clamp_max); -} - -size_t get_preferred_alignement() { - return 16; -} -} // namespace neon_dotprod_1x4x32 -} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p -} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h deleted file mode 100644 index 336d5a8e7f..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include - -#include - -namespace torchao::kernels::cpu::aarch64::kleidi { -namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { -namespace neon_dotprod_1x8x32 { -const Ukernel get_ukernel() { - return Ukernel{ - .get_m_step = - kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_n_step = - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_mr = - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_nr = - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_kr = - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_sr = - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_dst_offset = - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .get_dst_size = - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - .run_matmul = - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}; -} - -size_t activation_data_size(int m, int k, int group_size) { - (void) group_size; // unused - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(get_ukernel(), m, k); -} - -void prepare_activation_data( - void* prepared_activation_data, - int m, - int k, - int group_size, - const float* activations) { - (void) group_size; // unused - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), - prepared_activation_data, - m, - k, - activations); -} - -size_t weight_data_size(int n, int k, int group_size) { - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(get_ukernel(), n, k, group_size); -} - -void prepare_weight_data( - void* prepared_weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( - get_ukernel(), - prepared_weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -void kernel( - float32_t* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max) { - if (clamp_min == 0 && clamp_max == 0) { - clamp_min = std::numeric_limits::lowest(); - clamp_max = std::numeric_limits::max(); - } - - auto ukernel = get_ukernel(); - ukernel.run_matmul( - m, - n, - k, - group_size, - activation_data, - weight_data, - output, - /*dst_stride_row=*/ output_m_stride * sizeof(float), - /*dst_stride_col=*/ sizeof(float), - clamp_min, - clamp_max); -} - -size_t get_preferred_alignement() { - return 16; -} -} // namespace neon_dotprod_1x4x32 -} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p -} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h deleted file mode 100644 index 60004704ed..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once -#include -#include - -namespace torchao::kernels::cpu::aarch64::kleidi { -namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { -namespace neon_i8mm_8x4x32 { - -const Ukernel get_ukernel() { - return Ukernel{ - .get_m_step = - kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_n_step = - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_mr = - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_nr = - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_kr = - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_sr = - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_dst_offset = - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .get_dst_size = - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - .run_matmul = - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm}; -} - -size_t activation_data_size(int m, int k, int group_size) { - (void)group_size; // unused - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( - get_ukernel(), m, k); -} - -void prepare_activation_data( - void* prepared_activation_data, - int m, - int k, - int group_size, - const float* activations) { - (void)group_size; // unused - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), prepared_activation_data, m, k, activations); -} - -size_t weight_data_size(int n, int k, int group_size) { - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( - get_ukernel(), n, k, group_size); -} - -void prepare_weight_data( - void* prepared_weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( - get_ukernel(), - prepared_weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -void kernel( - float32_t* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max) { - if (clamp_min == 0 && clamp_max == 0) { - clamp_min = std::numeric_limits::lowest(); - clamp_max = std::numeric_limits::max(); - } - - auto ukernel = get_ukernel(); - ukernel.run_matmul( - m, - n, - k, - group_size, - activation_data, - weight_data, - output, - /*dst_stride_row=*/output_m_stride * sizeof(float), - /*dst_stride_col=*/sizeof(float), - clamp_min, - clamp_max); -} - -size_t get_preferred_alignement() { - return 16; -} -} // namespace neon_i8mm_8x4x32 -} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p -} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h deleted file mode 100644 index 90db4ae3d6..0000000000 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once -#include - -#include - -namespace torchao::kernels::cpu::aarch64::kleidi { -namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { -namespace neon_i8mm_4x8x32 { - -const Ukernel get_ukernel() { - return Ukernel{ - .get_m_step = - kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_n_step = - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_mr = - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_nr = - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_kr = - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_sr = - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_lhs_packed_offset = - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_rhs_packed_offset = - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_dst_offset = - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .get_dst_size = - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - .run_matmul = - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}; -} - -size_t activation_data_size(int m, int k, int group_size) { - (void)group_size; // unused - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size( - get_ukernel(), m, k); -} - -void prepare_activation_data( - void* prepared_activation_data, - int m, - int k, - int group_size, - const float* activations) { - (void)group_size; // unused - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data( - get_ukernel(), prepared_activation_data, m, k, activations); -} - -size_t weight_data_size(int n, int k, int group_size) { - return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size( - get_ukernel(), n, k, group_size); -} - -void prepare_weight_data( - void* prepared_weight_data, - int n, - int k, - int group_size, - const int8_t* weight_qvals, - const float* weight_scales, - const int8_t* weight_zeros, - const float* bias) { - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data( - get_ukernel(), - prepared_weight_data, - n, - k, - group_size, - weight_qvals, - weight_scales, - weight_zeros, - bias); -} - -void kernel( - float32_t* output, - int output_m_stride, - int m, - int n, - int k, - int group_size, - const void* weight_data, - const void* activation_data, - float clamp_min, - float clamp_max) { - if (clamp_min == 0 && clamp_max == 0) { - clamp_min = std::numeric_limits::lowest(); - clamp_max = std::numeric_limits::max(); - } - - auto ukernel = get_ukernel(); - ukernel.run_matmul( - m, - n, - k, - group_size, - activation_data, - weight_data, - output, - /*dst_stride_row=*/output_m_stride * sizeof(float), - /*dst_stride_col=*/sizeof(float), - clamp_min, - clamp_max); -} - -size_t get_preferred_alignement() { - return 16; -} - -} // namespace neon_i8mm_4x8x32 -} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p -} // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h index 9cde684995..9071869fce 100644 --- a/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h +++ b/torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h @@ -14,8 +14,15 @@ #include #include +#include +#include #include +#ifdef TORCHAO_ENABLE_ARM_I8MM +#include +#include +#endif // TORCHAO_ENABLE_ARM_I8MM + #include namespace torchao::kernels::cpu::aarch64::kleidi { @@ -23,7 +30,9 @@ namespace torchao::kernels::cpu::aarch64::kleidi { // Helper functions // TODO: find a better place for these? -size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } +namespace internal { + +inline size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } uint16_t get_bf16_from_float(float f) { uint16_t bf16; @@ -37,46 +46,59 @@ uint16_t get_bf16_from_float(float f) { return bf16; } +// KleidiAI kernels require n is even, so we round up to next even number +// if required and pad +inline int adjust_n(int n) { return roundup(n, 2); } + +} // namespace internal + namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p { using Ukernel = struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel; -size_t activation_data_size(const Ukernel ukernel, int m, int k) { +template +size_t activation_data_size(int m, int k, int group_size) { + (void)group_size; // unused auto lhs_packing = get_lhs_packing(); - return lhs_packing.get_lhs_packed_size(m, k, ukernel.get_mr(), - ukernel.get_kr(), ukernel.get_sr()); + return lhs_packing.get_lhs_packed_size(m, k, mr, kr, sr); } -void prepare_activation_data(const Ukernel ukernel, void *activation_data, - int m, int k, const float *activations) { +template +void prepare_activation_data(void *activation_data, int m, int k, + int group_size, const float *activations) { + (void)group_size; // unused auto lhs_pack = get_lhs_packing(); - lhs_pack.run_lhs_pack(m, k, ukernel.get_mr(), ukernel.get_kr(), - ukernel.get_sr(), + lhs_pack.run_lhs_pack(m, k, mr, kr, sr, /*m_index_start=*/0, activations, /*lhs_stride=*/k * sizeof(float), activation_data); } -size_t weight_data_size(const Ukernel ukernel, int n, int k, int group_size) { +template +size_t weight_data_size(int n, int k, int group_size) { auto rhs_pack = get_rhs_packing(); - return rhs_pack.get_rhs_packed_size(n, k, ukernel.get_nr(), ukernel.get_kr(), - ukernel.get_sr(), group_size, + return rhs_pack.get_rhs_packed_size(n, k, nr, kr, sr, group_size, kai_datatype::kai_dt_bf16); } -void prepare_weight_data(const Ukernel ukernel, void *weight_data, int n, int k, - int group_size, const int8_t *weight_qvals, - const float *weight_scales, const int8_t *weight_zeros, - const float *bias) { - // TODO(T204312268) - remove this constraint and pad when possible - assert(n % 2 == 0); +template +void prepare_weight_data(void *weight_data, int n, int k, int group_size, + const int8_t *weight_qvals, const float *weight_scales, + const int8_t *weight_zeros, const float *bias) { - assert(group_size % 32 == 0); - assert(k % group_size == 0); + if (group_size % 32 != 0) { + throw std::runtime_error( + "Group size must be a multiple of 32, but got group_size=" + + std::to_string(group_size)); + } + if (k % group_size != 0) { + throw std::runtime_error( + "k must be a multiple of group size, but got k=" + std::to_string(k) + + " and group_size=" + std::to_string(group_size)); + } // TODO SIMDify this size_t n_groups = n * k / group_size; - auto weight_scales_bf16 = std::vector(n_groups, 0); // We don't support weight zeros yet if (weight_zeros != nullptr) { @@ -85,18 +107,29 @@ void prepare_weight_data(const Ukernel ukernel, void *weight_data, int n, int k, } } + auto weight_scales_bf16_padded = + std::vector(internal::adjust_n(n) * k / group_size, 0); for (size_t i = 0; i < n_groups; i++) { - weight_scales_bf16[i] = get_bf16_from_float(weight_scales[i]); + weight_scales_bf16_padded[i] = + internal::get_bf16_from_float(weight_scales[i]); } // Prepack weights before packing // TODO SIMDify this - auto packed_weight_qvals = std::vector(n * k / 2, 0); + auto packed_weight_qvals_padded = + std::vector(internal::adjust_n(n) * k / 2, 0); uint8_t wzp = 8; for (size_t i = 0; i < n * k; i += 2) { const uint8_t low = static_cast(weight_qvals[i] + wzp); const uint8_t high = static_cast(weight_qvals[i + 1] + wzp); - packed_weight_qvals[i / 2] = ((high << 4) | (low & 0xF)); + packed_weight_qvals_padded[i / 2] = ((high << 4) | (low & 0xF)); + } + + auto bias_padded = std::vector(internal::adjust_n(n), 0.0); + if (bias != nullptr) { + for (size_t i = 0; i < n; i++) { + bias_padded[i] = bias[i]; + } } // Parameters for packing @@ -107,17 +140,68 @@ void prepare_weight_data(const Ukernel ukernel, void *weight_data, int n, int k, auto rhs_pack = get_rhs_packing(); rhs_pack.run_rhs_pack( - /*groups=*/1, n, k, ukernel.get_nr(), ukernel.get_kr(), ukernel.get_sr(), - group_size, - /*rhs=*/reinterpret_cast(packed_weight_qvals.data()), - /*rhs_stride=*/roundup(k, 2) / 2, - /*bias=*/bias, - /*scale=*/reinterpret_cast(weight_scales_bf16.data()), - /*scale_stride=*/sizeof(uint16_t) * (roundup(k, group_size) / group_size), + /*groups=*/1, internal::adjust_n(n), k, nr, kr, sr, group_size, + /*rhs=*/ + reinterpret_cast(packed_weight_qvals_padded.data()), + /*rhs_stride=*/internal::roundup(k, 2) / 2, + /*bias=*/reinterpret_cast(bias_padded.data()), + /*scale=*/ + reinterpret_cast(weight_scales_bf16_padded.data()), + /*scale_stride=*/sizeof(uint16_t) * + (internal::roundup(k, group_size) / group_size), /*rhs_packed=*/weight_data, /*extra_bytes=*/0, /*qparams=*/&qparams); } +size_t get_preferred_alignement() { return 16; } + +#define DEFINE_KERNEL_STRUCT(name) \ + struct name { \ + inline static kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel \ + get_ukernel() { \ + return kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel( \ + {.get_m_step = kai_get_m_step_##name, \ + .get_n_step = kai_get_n_step_##name, \ + .get_mr = kai_get_mr_##name, \ + .get_nr = kai_get_nr_##name, \ + .get_kr = kai_get_kr_##name, \ + .get_sr = kai_get_sr_##name, \ + .get_lhs_packed_offset = kai_get_lhs_packed_offset_##name, \ + .get_rhs_packed_offset = kai_get_rhs_packed_offset_##name, \ + .get_dst_offset = kai_get_dst_offset_##name, \ + .get_dst_size = kai_get_dst_size_##name, \ + .run_matmul = kai_run_##name}); \ + } \ + inline static void kernel(float32_t *output, int output_m_stride, int m, \ + int n, int k, int group_size, \ + const void *weight_data, \ + const void *activation_data, float clamp_min, \ + float clamp_max) { \ + if (clamp_min == 0 && clamp_max == 0) { \ + clamp_min = std::numeric_limits::lowest(); \ + clamp_max = std::numeric_limits::max(); \ + } \ + get_ukernel().run_matmul( \ + m, internal::adjust_n(n), k, group_size, activation_data, \ + weight_data, output, \ + /*dst_stride_row=*/output_m_stride * sizeof(float), \ + /*dst_stride_col=*/sizeof(float), /*clamp_min=*/clamp_min, \ + /*clamp_max=*/clamp_max); \ + } \ + } + +DEFINE_KERNEL_STRUCT( + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); +DEFINE_KERNEL_STRUCT( + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod); + +#ifdef TORCHAO_ENABLE_ARM_I8MM +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm); +DEFINE_KERNEL_STRUCT(matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm); +#endif // TORCHAO_ENABLE_ARM_I8MM + +#undef DEFINE_KERNEL_STRUCT + } // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p } // namespace torchao::kernels::cpu::aarch64::kleidi diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 5c12d7184e..39cc76d887 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -12,8 +12,6 @@ export CMAKE_OUT=/tmp/cmake-out/torch_ao/kernel_tests target=${1:-"native"} -IS_ARM64=0 -BUILD_ARM_I8MM=0 EXTRA_ARGS="" if [[ "${target}" == "android" ]]; then if [[ -z ${ANDROID_NDK} ]]; then @@ -38,17 +36,10 @@ if [[ "${target}" == "android" ]]; then echo "Building tests for Android (${android_abi}) @ ${CMAKE_OUT}" fi -hash arch; retval=$? -if [[ ${retval} -eq 0 && $(arch) == "arm64" ]]; then - IS_ARM64=1 -fi - cmake \ ${EXTRA_ARGS} \ -DCMAKE_BUILD_TYPE=Debug \ -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ - -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ - -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \ -B ${CMAKE_OUT} diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp index 070e7bebfb..073e612c68 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_linear.cpp @@ -14,15 +14,6 @@ #include #include -#ifdef TORCHAO_ENABLE_KLEIDI -#include -#include -#ifdef TORCHAO_ENABLE_ARM_I8MM -#include -#include -#endif // TORCHAO_ENABLE_ARM_I8MM -#endif // TORCHAO_ENABLE_KLEIDI - float kTol = 0.0001; template @@ -269,327 +260,4 @@ TEST( } } -#ifdef TORCHAO_ENABLE_KLEIDI -template -void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod( - int m, int k, int n, int group_size) { - auto test_case = - torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: - generate(m, k, n, group_size, - /*weight_nbit=*/4, - /*has_weight_zeros*/ false, has_bias, has_clamp, - /*weight_scale_bf16_round_trip=*/true); - - using namespace torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32; - - std::vector activation_data(activation_data_size(m, k, group_size)); - - prepare_activation_data((void *)activation_data.data(), m, k, group_size, - test_case.activations.data()); - - std::vector weight_data(weight_data_size(n, k, group_size)); - - prepare_weight_data((void *)weight_data.data(), n, k, group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); - - std::vector output(m * n); - kernel(output.data(), - /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - k_eq_gs_32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - large_k_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - even_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, - m_clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -template -void test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod( - int m, int k, int n, int group_size) { - auto test_case = - torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: - generate(m, k, n, group_size, - /*weight_nbit=*/4, - /*has_weight_zeros=*/false, has_bias, has_clamp, - /*round_weight_scales_to_bf16=*/true); - - using namespace torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32; - - std::vector activation_data(activation_data_size(m, k, group_size)); - - prepare_activation_data((void *)activation_data.data(), m, k, group_size, - test_case.activations.data()); - - std::vector weight_data(weight_data_size(n, k, group_size)); - - prepare_weight_data((void *)weight_data.data(), n, k, group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); - - std::vector output(m * n); - kernel(output.data(), - /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - k_eq_gs_32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - large_k_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - even_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - m_clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -#ifdef TORCHAO_ENABLE_ARM_I8MM -template -void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm( - int m, int k, int n, int group_size) { - auto test_case = - torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: - generate(m, k, n, group_size, - /*weight_nbit=*/4, - /*has_weight_zeros=*/false, has_bias, has_clamp, - /*round_weight_scales_to_bf16=*/true); - - using namespace torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32; - - std::vector activation_data(activation_data_size(m, k, group_size)); - - prepare_activation_data((void *)activation_data.data(), m, k, group_size, - test_case.activations.data()); - - std::vector weight_data(weight_data_size(n, k, group_size)); - - prepare_weight_data((void *)weight_data.data(), n, k, group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); - - std::vector output(m * n); - kernel(output.data(), - /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - k_eq_gs_32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - large_k_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - even_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, - m_clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -template -void test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm( - int m, int k, int n, int group_size) { - auto test_case = - torchao::channelwise_8bit_activation_groupwise_lowbit_weight_test_case:: - generate(m, k, n, group_size, - /*weight_nbit=*/4, - /*has_weight_zeros=*/false, has_bias, has_clamp, - /*round_weight_scales_to_bf16=*/true); - - using namespace torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32; - - std::vector activation_data(activation_data_size(m, k, group_size)); - - prepare_activation_data((void *)activation_data.data(), m, k, group_size, - test_case.activations.data()); - - std::vector weight_data(weight_data_size(n, k, group_size)); - - prepare_weight_data((void *)weight_data.data(), n, k, group_size, - test_case.weight_qvals.data(), - test_case.weight_scales.data(), - /*weight_zeros=*/test_case.weight_zeros.data(), - /*bias=*/test_case.bias.data()); - - std::vector output(m * n); - kernel(output.data(), - /*output_m_stride=*/n, m, n, k, group_size, weight_data.data(), - activation_data.data(), - /*clamp_min=*/test_case.clamp_min, - /*clamp_max=*/test_case.clamp_max); - - for (int i = 0; i < m * n; i++) { - EXPECT_NEAR(output[i], test_case.expected_output[i], kTol); - } -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - k_eq_gs_32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/32, /*n=*/4, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - large_k_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/512, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - even_n_gs32) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/1024, /*n=*/182, /*group_size=*/32); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, false /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/1, /*k=*/128, /*n=*/182, /*group_size=*/128); -} - -TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - m_clamp_k_eq_gs128) { - test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm< - false /*has_bias*/, true /*has_clamp*/>( - /*m=*/11, /*k=*/128, /*n=*/182, /*group_size=*/128); -} -#endif // TORCHAO_ENABLE_ARM_I8MM -#endif // TORCHAO_ENABLE_KLEIDI #endif // defined(__aarch64__) || defined(__ARM_NEON) diff --git a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h b/torchao/experimental/ops/embedding_xbit/packed_weights_header.h index 935ee3bfbd..8e47c2d1c0 100644 --- a/torchao/experimental/ops/embedding_xbit/packed_weights_header.h +++ b/torchao/experimental/ops/embedding_xbit/packed_weights_header.h @@ -16,7 +16,7 @@ inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( int max_value_chunk_size, int version = 1) { return torchao::ops::PackedWeightsHeader( - torchao::ops::PackedWeightsFormat::embedding_xbit_universal, + torchao::ops::PackedWeightsType::embedding_xbit_universal, {version, weight_nbit, min_value_chunk_size, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt b/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt index 91fcf60621..82d9fa2cf3 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt @@ -8,6 +8,16 @@ cmake_minimum_required(VERSION 3.19) include(${CMAKE_CURRENT_SOURCE_DIR}/../../Utils.cmake) + # For some reason cpuinfo package has unused functions/variables + # TODO (T215533422): fix upstream +add_compile_options(-Wno-unused-function -Wno-unused-variable) +include(FetchContent) +FetchContent_Declare(cpuinfo + GIT_REPOSITORY https://github.com/pytorch/cpuinfo.git + GIT_TAG aaac07ee499895770c89163ce0920ef8bb41ed23) +FetchContent_MakeAvailable( + cpuinfo) + find_package(Torch REQUIRED) add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT linear_8bit_act_xbit_weight.cpp @@ -15,6 +25,7 @@ add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT ) target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_aten aten_openmp) target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64) +target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE cpuinfo) target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}") target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}") target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE USE_ATEN=1) @@ -37,4 +48,5 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS) target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1) target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}") target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64) + target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE cpuinfo) endif() diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h new file mode 100644 index 0000000000..443d903dfb --- /dev/null +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -0,0 +1,361 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once +#include +#include +#include + +#if defined(__aarch64__) || defined(__ARM_NEON) +#include +#endif // defined(__aarch64__) || defined(__ARM_NEON) + +#include +#include +#include + +#if defined(TORCHAO_ENABLE_KLEIDI) +#include +#endif // TORCHAO_ENABLE_KLEIDI + +namespace torchao::ops::linear_8bit_act_xbit_weight { + +struct PackedWeightsFormat { + torchao::ops::PackedWeightsType type; + int weight_nbit; + bool has_weight_zeros; + bool has_bias; + int nr; + int kr; + int sr; + + PackedWeightsFormat(torchao::ops::PackedWeightsType type, int weight_nbit, + bool has_weight_zeros, bool has_bias, int nr, int kr, + int sr) + : type{type}, weight_nbit{weight_nbit}, + has_weight_zeros{has_weight_zeros}, has_bias{has_bias}, nr{nr}, kr{kr}, + sr{sr} {} + + static PackedWeightsFormat + from_packed_weights_header(torchao::ops::PackedWeightsHeader header) { + return PackedWeightsFormat( + header.type, header.params[0], static_cast(header.params[1]), + static_cast(header.params[2]), header.params[3], header.params[4], + header.params[5]); + } + + inline torchao::ops::PackedWeightsHeader to_packed_weights_header() const { + return torchao::ops::PackedWeightsHeader( + type, {weight_nbit, has_weight_zeros, has_bias, nr, kr, sr}); + } +}; + +struct UKernelConfigRegistrationTable { +private: + using Key = std::pair; + struct KeyHasher { + std::size_t operator()(const Key &k) const { + return std::hash()(k.first) ^ + std::hash()(static_cast(k.second)); + } + }; + std::unordered_map registration_table_; + inline Key make_key(torchao::ops::PackedWeightsHeader header, + cpuinfo_uarch uarch) const { + return std::make_pair(header, uarch); + } + +public: + void register_ukernel_config(PackedWeightsFormat format, cpuinfo_uarch uarch, + UKernelConfig config) { + auto header = format.to_packed_weights_header(); + auto key = make_key(header, uarch); + if (registration_table_.find(key) != registration_table_.end()) { + throw std::runtime_error( + "UKernelConfig is already registered for this format"); + } + registration_table_[key] = config; + } + std::optional + get_ukernel_config(torchao::ops::PackedWeightsHeader header, + cpuinfo_uarch uarch) const { + auto key = make_key(header, uarch); + auto it = registration_table_.find(key); + if (it == registration_table_.end()) { + return std::nullopt; + } + return it->second; + } +}; + +template +void check_format(PackedWeightsFormat format, + torchao::ops::PackedWeightsType type) { + if (format.type != type) { + throw std::runtime_error("Kernel expects packed_weights type=" + + std::to_string(static_cast(type)) + + ", but got packed_weights with type=" + + std::to_string(static_cast(format.type))); + } + if (format.weight_nbit != weight_nbit) { + throw std::runtime_error( + "Kernel expects weight_nbit=" + std::to_string(weight_nbit) + + ", but got packed_weights with weight_nbit=" + + std::to_string(format.weight_nbit)); + } + if (format.has_weight_zeros != has_weight_zeros) { + throw std::runtime_error( + "Kernel expects has_weight_zeros=" + std::to_string(has_weight_zeros) + + ", but got packed_weights with has_weight_zeros=" + + std::to_string(format.has_weight_zeros)); + } + if (format.has_bias != has_bias) { + throw std::runtime_error( + "Kernel expects has_bias=" + std::to_string(has_bias) + + ", but got packed_weights with has_bias=" + + std::to_string(format.has_bias)); + } +} + +template +void register_ukernel_config_universal(UKernelConfigRegistrationTable &table, + PackedWeightsFormat format, + cpuinfo_uarch uarch) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + check_format( + format, + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); + + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { +#if defined(__aarch64__) || defined(__ARM_NEON) + if (cpuinfo_has_arm_neon_dot()) { + namespace kernel = torchao::kernels::cpu::aarch64::linear:: + channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; + table.register_ukernel_config( + format, uarch, + UKernelConfig{ + /*preferred_alignment*/ 16, + /*nr*/ 8, + /*weight_packing_config*/ + {/*weight_data_size_fn*/ + &kernel::weight_data_size, + /*prepare_weight_data_fn*/ + &kernel::prepare_weight_data}, + /*linear_configs*/ + {{{/*mr*/ 1, + /*activation_data_size_fn*/ + &kernel::activation_data_size, + /*prepare_activation_data_fn*/ + &kernel::prepare_activation_data, + /*kernel*/ + &kernel::kernel}}}}); + return; + } +#endif // defined(__aarch64__) || defined(__ARM_NEON) + } +} + +#if defined(TORCHAO_ENABLE_KLEIDI) +template +UKernelConfig::linear_config_type get_linear_config_kleidi() { + namespace op = torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + assert(m_step == kernel_struct::get_ukernel().get_m_step()); + assert(mr == kernel_struct::get_ukernel().get_mr()); + assert(n_step == kernel_struct::get_ukernel().get_n_step()); + assert(nr == kernel_struct::get_ukernel().get_nr()); + assert(kr == kernel_struct::get_ukernel().get_kr()); + assert(sr == kernel_struct::get_ukernel().get_sr()); + return UKernelConfig::linear_config_type{ + /*mr*/ m_step, + /*activation_data_size_fn*/ &op::activation_data_size, + /*prepare_activation_data_fn*/ &op::prepare_activation_data, + /*kernel*/ &kernel_struct::kernel}; +} + +template +UKernelConfig::weight_packing_config_type get_weight_packing_config_kleidi() { + namespace op = torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + return UKernelConfig::weight_packing_config_type( + {/*weight_data_size_fn*/ &op::weight_data_size, + /*prepare_weight_data_fn*/ &op::prepare_weight_data}); +} + +template +void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table, + PackedWeightsFormat format, + cpuinfo_uarch uarch) { + if (!cpuinfo_initialize()) { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + check_format( + format, torchao::ops::PackedWeightsType::kleidi_ai); + namespace op = torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + + if (format.nr == 8 && format.kr == 16 && format.sr == 2) { + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; +#if defined(TORCHAO_ENABLE_ARM_I8MM) + if (cpuinfo_has_arm_i8mm()) { + constexpr int n_step = 8; + table.register_ukernel_config( + format, uarch, + UKernelConfig{ + /*preferred_alignment*/ op::get_preferred_alignement(), + /*nr*/ n_step, + /*weight_packing_config*/ + get_weight_packing_config_kleidi(), + /*linear_configs*/ + {{get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + /*m_step*/ 4, /*mr*/ 4, n_step, nr, kr, sr>()}}}); + return; + } +#endif // TORCHAO_ENABLE_ARM_I8MM + + if (cpuinfo_has_arm_neon_dot()) { + constexpr int n_step = 8; + table.register_ukernel_config( + format, uarch, + UKernelConfig{ + /*preferred_alignment*/ op::get_preferred_alignement(), + /*nr*/ n_step, + /*weight_packing_config*/ + get_weight_packing_config_kleidi(), + /*linear_configs*/ + {{get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + /*m_step*/ 1, /*mr*/ 1, n_step, nr, kr, sr>()}}}); + return; + } + } + + if (format.nr == 4 && format.kr == 16 && format.sr == 2) { + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + if (cpuinfo_has_arm_neon_dot()) { + constexpr int n_step = 4; + table.register_ukernel_config( + format, uarch, + UKernelConfig{ + /*preferred_alignment*/ op::get_preferred_alignement(), + /*nr*/ n_step, + /*weight_packing_config*/ + get_weight_packing_config_kleidi(), + /*linear_configs*/ + {{get_linear_config_kleidi< + op::matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + /*m_step*/ 1, /*mr*/ 1, n_step, nr, kr, sr>()}}}); + return; + } + } +} +#endif // TORCHAO_ENABLE_KLEIDI + +template +void register_ukernel_config(UKernelConfigRegistrationTable &table, + PackedWeightsFormat format, cpuinfo_uarch uarch) { + switch (format.type) { + case torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal: { + if (format.has_bias) { + register_ukernel_config_universal( + table, format, uarch); + } else { + register_ukernel_config_universal(table, format, + uarch); + } + break; + } + case torchao::ops::PackedWeightsType::kleidi_ai: { +#ifdef TORCHAO_ENABLE_KLEIDI + register_ukernel_config_kleidi(table, format, + uarch); +#endif // TORCHAO_ENABLE_KLEIDI + break; + } + default: + throw std::runtime_error( + "No registration available for packed_weights_type=" + + std::to_string(static_cast(format.type))); + } + + auto config = + table.get_ukernel_config(format.to_packed_weights_header(), uarch); + if (!config.has_value()) { + throw std::runtime_error("ukernel_config did not register"); + } +} + +// Not thread safe +template +UKernelConfig select_ukernel_config(torchao::ops::PackedWeightsHeader header) { + static UKernelConfigRegistrationTable table; + + // In future, we can populate this with the current thread's uarch + // That will require that select_ukernel_config be called in the lambda + // instead of before it on the main thread + // Note, cpuinfo_get_current_core() is not currently implemeted outside of + // linux XNNPACK often uses non-core specific logic like + // cpuinfo_get_core(0)->uarch in configs + auto uarch = cpuinfo_uarch_unknown; + auto ukernel = table.get_ukernel_config(header, uarch); + if (ukernel.has_value()) { + return ukernel.value(); + } + + auto format = PackedWeightsFormat::from_packed_weights_header(header); + register_ukernel_config(table, format, uarch); + + ukernel = table.get_ukernel_config(header, uarch); + assert(ukernel.has_value()); + return ukernel.value(); +} + +template +UKernelConfig select_ukernel_config(PackedWeightsFormat format) { + return select_ukernel_config( + format.to_packed_weights_header()); +} + +template +PackedWeightsFormat +select_packed_weights_format(std::optional target = std::nullopt) { +// Select KleidiAI format +#if defined(TORCHAO_ENABLE_KLEIDI) + if (!target || *target == "kleidi_ai") { + if constexpr (weight_nbit == 4 && + (!has_weight_zeros)) { // TODO: add has_bias here + return PackedWeightsFormat( + torchao::ops::PackedWeightsType::kleidi_ai, weight_nbit, + has_weight_zeros, /*has_bias*/ true, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); + } + } +#endif // defined(TORCHAO_ENABLE_KLEIDI) + + // Select universal format + if (!target || *target == "universal") { + return PackedWeightsFormat( + torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal, + weight_nbit, has_weight_zeros, has_bias, /*nr*/ 8, /*kr*/ 16, /*sr*/ 2); + } + + throw std::runtime_error("No packed_weights_format was selected"); +} + +} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp index 709386998e..1c23bdbbae 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.cpp @@ -31,7 +31,7 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params( assert(nc >= 1); // Replace nc with the next number nr divides - nc = ((nc + ukernel_config.nr - 1) / ukernel_config.nr) * ukernel_config.nr; + nc = ((nc + nr - 1) / nr) * nr; tiling_params.nc_by_nr = nc / nr; return tiling_params; @@ -59,16 +59,25 @@ void pack_weight_data_operator(const UKernelConfig &ukernel_config, int nc_tile_size = std::min(nc, n - n_idx); int weight_data_offset = - (n_idx / nr) * ukernel_config.weight_data_size_fn(nr, k, group_size); + (n_idx / nr) * ukernel_config.weight_packing_config.weight_data_size_fn( + nr, k, group_size); int weight_qvals_offset = n_idx * k; int weight_scales_and_zeros_offset = (n_idx * k / group_size); - int bias_offset = n_idx; - ukernel_config.prepare_weight_data_fn( + const int8_t *weight_zeros_ptr = nullptr; + if (weight_zeros != nullptr) { + weight_zeros_ptr = weight_zeros + weight_scales_and_zeros_offset; + } + const float *bias_ptr = nullptr; + if (bias != nullptr) { + bias_ptr = bias + n_idx; + } + + ukernel_config.weight_packing_config.prepare_weight_data_fn( (char *)weight_data + weight_data_offset, /*n=*/nc_tile_size, k, group_size, weight_qvals + weight_qvals_offset, - weight_scales + weight_scales_and_zeros_offset, - weight_zeros + weight_scales_and_zeros_offset, bias + bias_offset); + weight_scales + weight_scales_and_zeros_offset, weight_zeros_ptr, + bias_ptr); }); } @@ -86,7 +95,7 @@ get_default_linear_tiling_params(const UKernelConfig &ukernel_config, int m, TORCHAO_CHECK(num_threads >= 1, "num_threads must be >= 1"); tiling_params.mc_by_mr = 1; - int mc = tiling_params.mc_by_mr * ukernel_config.mr; + int mc = tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr; int num_mc_panels = (m + mc - 1) / mc; int numerator = n * num_mc_panels; @@ -97,9 +106,10 @@ get_default_linear_tiling_params(const UKernelConfig &ukernel_config, int m, assert(nc >= 1); // Replace nc with next number nr divides - nc = ((nc + ukernel_config.nr - 1) / ukernel_config.nr) * ukernel_config.nr; - assert(nc % ukernel_config.nr == 0); - tiling_params.nc_by_nr = nc / ukernel_config.nr; + int nr = ukernel_config.nr; + nc = ((nc + nr - 1) / nr) * nr; + assert(nc % nr == 0); + tiling_params.nc_by_nr = nc / nr; assert(tiling_params.mc_by_mr >= 1); assert(tiling_params.nc_by_nr >= 1); @@ -112,15 +122,17 @@ inline size_t get_activation_data_buffer_size_with_tile_schedule_policy_single_mc_parallel_nc( const UKernelConfig &ukernel_config, const LinearTilingParams &tiling_params, int m, int k, int group_size) { - return ukernel_config.activation_data_size_fn( - tiling_params.mc_by_mr * ukernel_config.mr, k, group_size); + return ukernel_config.linear_configs[0].activation_data_size_fn( + tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr, k, + group_size); } inline size_t get_activation_data_buffer_size_with_tile_schedule_policy_parallel_mc_parallel_nc( const UKernelConfig &ukernel_config, const LinearTilingParams &tiling_params, int m, int k, int group_size) { - return ukernel_config.activation_data_size_fn(m, k, group_size); + return ukernel_config.linear_configs[0].activation_data_size_fn(m, k, + group_size); } inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( @@ -134,20 +146,22 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( // Ignored if has_clamp = false float clamp_min, float clamp_max) { int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.mr); - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int mc = + std::min(m, tiling_params.mc_by_mr * ukernel_config.linear_configs[0].mr); + int nc = std::min(n, tiling_params.nc_by_nr * nr); int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; size_t weight_data_size = - ukernel_config.weight_data_size_fn(nr, k, group_size); + ukernel_config.weight_packing_config.weight_data_size_fn(nr, k, + group_size); for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) { int m_idx = mc_tile_idx * mc; int mc_tile_size = std::min(mc, m - m_idx); int activations_offset = m_idx * k; - ukernel_config.prepare_activation_data_fn(activation_data_buffer, - /*m=*/mc_tile_size, k, group_size, - activations + activations_offset); + ukernel_config.linear_configs[0].prepare_activation_data_fn( + activation_data_buffer, + /*m=*/mc_tile_size, k, group_size, activations + activations_offset); torchao::parallel_1d(0, num_nc_panels, [&](int64_t idx) { int nc_tile_idx = idx; @@ -157,7 +171,7 @@ inline void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc( int output_offset = m_idx * n + n_idx; int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.kernel_fn( + ukernel_config.linear_configs[0].kernel_fn( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, @@ -176,17 +190,19 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( // Inputs int m, int n, int k, int group_size, const void *weight_data, const float *activations, float clamp_min, float clamp_max) { - int mr = ukernel_config.mr; + int mr = ukernel_config.linear_configs[0].mr; int nr = ukernel_config.nr; - int mc = std::min(m, tiling_params.mc_by_mr * ukernel_config.mr); - int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr); + int mc = std::min(m, tiling_params.mc_by_mr * mr); + int nc = std::min(n, tiling_params.nc_by_nr * nr); int num_mc_panels = (m + mc - 1) / mc; int num_nc_panels = (n + nc - 1) / nc; size_t weight_data_size = - ukernel_config.weight_data_size_fn(nr, k, group_size); + ukernel_config.weight_packing_config.weight_data_size_fn(nr, k, + group_size); size_t activation_data_size = - ukernel_config.activation_data_size_fn(mr, k, group_size); + ukernel_config.linear_configs[0].activation_data_size_fn(mr, k, + group_size); torchao::parallel_1d(0, num_mc_panels, [&](int64_t idx) { int mc_tile_idx = idx; @@ -195,7 +211,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( int activations_offset = m_idx * k; int activation_data_offset = (m_idx / mr) * activation_data_size; - ukernel_config.prepare_activation_data_fn( + ukernel_config.linear_configs[0].prepare_activation_data_fn( activation_data_buffer + activation_data_offset, /*m=*/mc_tile_size, k, group_size, activations + activations_offset); }); @@ -213,7 +229,7 @@ inline void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc( int output_offset = m_idx * n + n_idx; int weight_data_offset = (n_idx / nr) * weight_data_size; - ukernel_config.kernel_fn( + ukernel_config.linear_configs[0].kernel_fn( output + output_offset, /*output_m_stride=*/n, /*m=*/mc_tile_size, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h index 1dc69dee74..6742f88b02 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #pragma once +#include #include #include #include @@ -29,27 +30,24 @@ struct UKernelConfig { const void *activation_data, float clamp_min, float clamp_max); - activation_data_size_fn_type activation_data_size_fn{nullptr}; - // preferred_activation_data_alignment is only a preferred alignment for - // performance reasons. Integration surfaces are not required to - // respect this alignment, and the ukernel must behave correctly no matter - // how the prepared_activation_data byte-array is aligned - size_t preferred_activation_data_alignment{0}; - prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; - - weight_data_size_fn_type weight_data_size_fn{nullptr}; - // weight_data_alignment is only a preferred alignment for - // performance reasons. Integration surfaces are not required to - // respect this alignment, and the ukernel must behave correctly no matter - // how the prepared_weight_data byte-array is aligned - size_t preferred_weight_data_alignment{0}; - prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; - - kernel_fn_type kernel_fn{nullptr}; - int mr{0}; + struct weight_packing_config_type { + weight_data_size_fn_type weight_data_size_fn{nullptr}; + prepare_weight_data_fn_type prepare_weight_data_fn{nullptr}; + }; + struct linear_config_type { + int mr{0}; + activation_data_size_fn_type activation_data_size_fn{nullptr}; + prepare_activation_data_fn_type prepare_activation_data_fn{nullptr}; + kernel_fn_type kernel_fn{nullptr}; + }; + + // preferred_alignment for activation and weight data + // Integration surfaces are not required to respect this alignment, and the + // ukernel must behave correctly no matter how buffers are aligned + size_t preferred_alignment{0}; int nr{0}; - - torchao::ops::PackedWeightsHeader packed_weights_header; + weight_packing_config_type weight_packing_config; + std::array linear_configs; }; // Pack weight functions @@ -64,12 +62,13 @@ get_default_pack_weight_data_tiling_params(const UKernelConfig &ukernel_config, inline size_t get_packed_weight_data_size(const UKernelConfig &ukernel_config, int n, int k, int group_size) { - return ukernel_config.weight_data_size_fn(n, k, group_size); + return ukernel_config.weight_packing_config.weight_data_size_fn(n, k, + group_size); } inline size_t get_preferred_packed_weight_data_alignment( const UKernelConfig &ukernel_config) { - return ukernel_config.preferred_weight_data_alignment; + return ukernel_config.preferred_alignment; } void pack_weight_data_operator(const UKernelConfig &ukernel_config, @@ -105,7 +104,7 @@ get_activation_data_buffer_size(const UKernelConfig &ukernel_config, inline size_t get_preferred_activation_data_buffer_alignment( const UKernelConfig &ukernel_config) { - return ukernel_config.preferred_activation_data_alignment; + return ukernel_config.preferred_alignment; } void linear_operator(const UKernelConfig &ukernel_config, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index bc88c0b725..364dd7b668 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -12,67 +12,13 @@ #include #include +#include #include -#include #include #include namespace { -// This selects a UkernelConfig based on the packed weight header -template -inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig -get_ukernel_config(torchao::ops::PackedWeightsHeader header) { - torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig config; - - switch (header.format) { -#if defined(__aarch64__) || defined(__ARM_NEON) - case torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal: - namespace ukernel - = torchao::kernels::cpu::aarch64::linear:: - channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - - // Check packing params match the kernel - TORCHAO_CHECK(header == torchao::ops::linear_8bit_act_xbit_weight:: - get_packed_weights_header_universal( - weight_nbit, has_weight_zeros, has_bias, - /*nr=*/8, - /*kr=*/16), - "Packing params do not match what kernel supports"); - - config.packed_weights_header = header; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - return config; - break; -#endif // defined(__aarch64__) || defined(__ARM_NEON) - default: - TORCHAO_CHECK(false, "Unsupported packed weights format"); - } -} - -template -inline torchao::ops::linear_8bit_act_xbit_weight::UKernelConfig -get_ukernel_config() { - auto header = torchao::ops::linear_8bit_act_xbit_weight:: - get_packed_weights_header_universal(weight_nbit, has_weight_zeros, - has_bias, /*nr=*/8, /*kr=*/16); - return get_ukernel_config( - header); -} - #ifdef USE_ATEN template Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales, @@ -114,8 +60,12 @@ Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales, using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto ukernel_config = get_ukernel_config(); + auto packed_weights_format = + select_packed_weights_format(); + auto packed_weights_header = packed_weights_format.to_packed_weights_header(); + auto ukernel_config = select_ukernel_config( + packed_weights_header); + auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params( ukernel_config, n, /*target_panels_per_thread=*/1); @@ -124,15 +74,16 @@ Tensor pack_weights_cpu(const Tensor &weight_qvals, const Tensor &weight_scales, get_packed_weight_data_size(ukernel_config, n, k, group_size); Tensor packed_weights = torch::empty( {static_cast(packed_weight_data_size)}, torch::kInt8); - ukernel_config.packed_weights_header.write( - packed_weights.mutable_data_ptr()); - pack_weight_data_operator( - ukernel_config, pack_weight_tiling_params, - packed_weights.mutable_data_ptr() + - torchao::ops::PackedWeightsHeader::size(), - n, k, group_size, weight_qvals.const_data_ptr(), - weight_scales.const_data_ptr(), weight_zeros_ptr, - /*bias*/ nullptr); + packed_weights_header.write(packed_weights.mutable_data_ptr()); + + // TODO: support passing in bias in future + pack_weight_data_operator(ukernel_config, pack_weight_tiling_params, + packed_weights.mutable_data_ptr() + + torchao::ops::PackedWeightsHeader::size(), + n, k, group_size, + weight_qvals.const_data_ptr(), + weight_scales.const_data_ptr(), + weight_zeros_ptr, /*bias*/ nullptr); return packed_weights; } @@ -181,8 +132,10 @@ Tensor pack_weights_meta(const Tensor &weight_qvals, using namespace torchao::ops::linear_8bit_act_xbit_weight; - auto ukernel_config = get_ukernel_config(); + auto packed_weights_format = + select_packed_weights_format(); + auto ukernel_config = select_ukernel_config( + packed_weights_format); auto packed_weight_data_size = torchao::ops::PackedWeightsHeader::size() + @@ -278,18 +231,19 @@ linear_out_cpu(const Tensor &activations, const Tensor &packed_weights, torchao::ops::PackedWeightsHeader::read(packed_weights.const_data_ptr()); auto ukernel_config = - get_ukernel_config(header); + select_ukernel_config(header); auto linear_tiling_params = get_default_linear_tiling_params(ukernel_config, m, n, /*target_tiles_per_thread=*/5); + auto linear_scheduling_policy = LinearTileSchedulingPolicy::single_mc_parallel_nc; auto activation_data_buffer_size = get_activation_data_buffer_size( ukernel_config, linear_tiling_params, linear_scheduling_policy, m, k, group_size); + std::vector activation_data_buffer(activation_data_buffer_size); linear_operator(ukernel_config, linear_tiling_params, diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h deleted file mode 100644 index d86a429461..0000000000 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/packed_weights_header.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once -#include -#include - -namespace torchao::ops::linear_8bit_act_xbit_weight { - -inline torchao::ops::PackedWeightsHeader get_packed_weights_header_universal( - int weight_nbit, - bool has_weight_zeros, - bool has_bias, - int nr, - int kr, - int version = 1) { - return torchao::ops::PackedWeightsHeader( - torchao::ops::PackedWeightsFormat::linear_8bit_act_xbit_weight_universal, - {version, - weight_nbit, - has_weight_zeros, - has_bias, - nr, - kr, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0}); -} - -} // namespace torchao::ops::linear_8bit_act_xbit_weight diff --git a/torchao/experimental/ops/packed_weights_header.h b/torchao/experimental/ops/packed_weights_header.h index 7184da4b46..213ec34f7f 100644 --- a/torchao/experimental/ops/packed_weights_header.h +++ b/torchao/experimental/ops/packed_weights_header.h @@ -12,35 +12,36 @@ namespace torchao::ops { -enum class PackedWeightsFormat : uint32_t { +enum class PackedWeightsType : uint32_t { unknown = 0, linear_8bit_act_xbit_weight_universal = 1, - embedding_xbit_universal = 2 + embedding_xbit_universal = 2, + kleidi_ai = 3 }; class PackedWeightsHeader { public: using params_type = std::array; const static int magic = 6712; - PackedWeightsFormat format; + PackedWeightsType type; - // 14 bytes of format specific params + // 14 bytes of type specific params params_type params; PackedWeightsHeader( - PackedWeightsFormat format = PackedWeightsFormat::unknown, + PackedWeightsType type = PackedWeightsType::unknown, params_type params = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) - : format{format}, params{params} {} + : type{type}, params{params} {} inline static constexpr int size() { - static_assert(sizeof(magic) + sizeof(format) + sizeof(params) == 64); + static_assert(sizeof(magic) + sizeof(type) + sizeof(params) == 64); return 64; } inline void write(void* packed_weights) const { auto header = reinterpret_cast(packed_weights); header[0] = magic; - header[1] = static_cast(format); + header[1] = static_cast(type); for (int i = 0; i < params.size(); i++) { header[i + 2] = params[i]; } @@ -54,11 +55,11 @@ class PackedWeightsHeader { params[i] = header[i + 2]; } return PackedWeightsHeader( - static_cast(header[1]), params); + static_cast(header[1]), params); } bool operator==(const PackedWeightsHeader& other) const { - if (format != other.format) { + if (type != other.type) { return false; } for (int i = 0; i < params.size(); i++) { @@ -71,3 +72,16 @@ class PackedWeightsHeader { }; } // namespace torchao::ops + +namespace std { + template <> + struct hash { + std::size_t operator()(const torchao::ops::PackedWeightsHeader& f) const { + std::size_t hash = std::hash()(static_cast(f.type)); + for (int i = 0; i < f.params.size(); i++) { + hash ^= std::hash()(f.params[i]); + } + return hash; + }; +}; +} diff --git a/torchao/experimental/ops/tests/build_and_run_tests.sh b/torchao/experimental/ops/tests/build_and_run_tests.sh index 4070b9304f..cff7ca639a 100644 --- a/torchao/experimental/ops/tests/build_and_run_tests.sh +++ b/torchao/experimental/ops/tests/build_and_run_tests.sh @@ -9,6 +9,8 @@ target=${1:-"native"} SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests +export TORCH_DIR = $(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib() + '/torch/share/cmake/Torch')") + IS_ARM64=0 BUILD_ARM_I8MM=0 EXTRA_ARGS="" @@ -45,6 +47,7 @@ cmake \ -DCMAKE_BUILD_TYPE=Debug \ -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ + -DTorch_DIR=${TORCH_DIR} \ -S . \ -B ${CMAKE_OUT} diff --git a/torchao/experimental/ops/tests/generate_tests.py b/torchao/experimental/ops/tests/generate_tests.py index 1710a90c49..160d8fa47a 100755 --- a/torchao/experimental/ops/tests/generate_tests.py +++ b/torchao/experimental/ops/tests/generate_tests.py @@ -51,6 +51,11 @@ def get_test_block(kernel): tests += add_test_string(kernel, 1, 2 * 13, 32, 32, True, False) tests += add_test_string(kernel, 1, 2 * 51, 32, 32, False, True) tests += add_test_string(kernel, 1, 2 * 111, 32, 32, False, False) + ## larger: n (odd) + tests += add_test_string(kernel, 1, 11, 32, 32, False, False) + tests += add_test_string(kernel, 1, 13, 32, 32, True, False) + tests += add_test_string(kernel, 1, 51, 32, 32, False, True) + tests += add_test_string(kernel, 1, 111, 32, 32, False, False) ## larger: k, g - must be multiple of 32 tests += add_test_string(kernel, 1, 2 * 7, 64, 32, False, False) tests += add_test_string(kernel, 1, 2 * 11, 128, 32, True, False) @@ -75,6 +80,11 @@ def get_test_block(kernel): tests += add_test_string(kernel, 17, 2 * 13, 32, 32, True, False) tests += add_test_string(kernel, 23, 2 * 51, 32, 32, False, True) tests += add_test_string(kernel, 41, 2 * 111, 32, 32, False, False) + ## larger: n (odd) + tests += add_test_string(kernel, 7, 11, 32, 32, False, False) + tests += add_test_string(kernel, 17, 13, 32, 32, True, False) + tests += add_test_string(kernel, 23, 51, 32, 32, False, True) + tests += add_test_string(kernel, 41, 111, 32, 32, False, False) ## larger: k, g - must be multiple of 32 tests += add_test_string(kernel, 19, 2 * 7, 64, 32, False, False) tests += add_test_string(kernel, 23, 2 * 11, 128, 32, True, False) diff --git a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp index bcf746e00e..295b93c3a4 100644 --- a/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp +++ b/torchao/experimental/ops/tests/test_linear_8bit_act_xbit_weight.cpp @@ -13,40 +13,36 @@ #include #if defined(TORCHAO_ENABLE_KLEIDI) -#include -#include -#if defined(TORCHAO_ENABLE_ARM_I8MM) -#include -#include -#endif // TORCHAO_ENABLE_ARM_I8MM +#include #endif // TORCHAO_ENABLE_KLEIDI const float kTol = 1.0e-5; using namespace torchao::ops::linear_8bit_act_xbit_weight; +using namespace torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; template UKernelConfig get_ukernel_config() { - UKernelConfig config; - - namespace ukernel = torchao::kernels::cpu::aarch64::linear:: + namespace kernel = torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; - config.mr = 1; - config.nr = 8; - config.activation_data_size_fn = - &ukernel::activation_data_size; - config.preferred_activation_data_alignment = 16; // size of neon register - config.prepare_activation_data_fn = - &ukernel::prepare_activation_data; - config.weight_data_size_fn = - &ukernel::weight_data_size; - config.preferred_weight_data_alignment = 16; // size of neon register - config.prepare_weight_data_fn = - &ukernel::prepare_weight_data; - config.kernel_fn = - &ukernel::kernel; - - return config; + return UKernelConfig{ + /*preferred_alignment*/ 16, + /*nr*/ 8, + /*weight_packing_config*/ + {/*weight_data_size_fn*/ + &kernel::weight_data_size, + /*prepare_weight_data_fn*/ + &kernel::prepare_weight_data}, + /*linear_configs*/ + {{{/*mr*/ 1, + /*activation_data_size_fn*/ + &kernel::activation_data_size, + /*prepare_activation_data_fn*/ + &kernel::prepare_activation_data, + /*kernel*/ + &kernel::kernel}}}}; } template +UKernelConfig get_ukernel_config_kleidi() { + namespace op = torchao::kernels::cpu::aarch64::kleidi:: + kai_matmul_clamp_f32_qai8dxp_qsi4c32p; + auto uk = kernel_struct::get_ukernel(); + assert(m_step == uk.get_m_step()); + assert(mr == uk.get_mr()); + assert(n_step == uk.get_n_step()); + assert(nr == uk.get_nr()); + assert(kr == uk.get_kr()); + assert(sr == uk.get_sr()); + return UKernelConfig{ + op::get_preferred_alignement(), + n_step, + {/*weight_data_size_fn*/ &op::weight_data_size, + /*prepare_weight_data_fn*/ &op::prepare_weight_data}, + {{{m_step, &op::activation_data_size, + &op::prepare_activation_data, &kernel_struct::kernel}}}}; +} template UKernelConfig get_ukernel_config_kleidi() { - UKernelConfig config; #if defined(TORCHAO_ENABLE_ARM_I8MM) if constexpr (kernel_id == i8mm_4x8x32) { - KAI_GEN_UKERNEL( - torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32); - return config; + constexpr int m_step = 4; + constexpr int mr = 4; + constexpr int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + return get_ukernel_config_kleidi< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, m_step, mr, + n_step, nr, kr, sr>(); } if constexpr (kernel_id == i8mm_8x4x32) { - KAI_GEN_UKERNEL( - torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32); - return config; + constexpr int m_step = 8; + constexpr int mr = 8; + constexpr int n_step = 4; + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + return get_ukernel_config_kleidi< + matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, m_step, mr, + n_step, nr, kr, sr>(); } #endif // TORCHAO_ENABLE_ARM_I8MM if constexpr (kernel_id == dotprod_1x8x32) { - KAI_GEN_UKERNEL( - torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x8x32); - return config; + constexpr int m_step = 1; + constexpr int mr = 1; + constexpr int n_step = 8; + constexpr int nr = 8; + constexpr int kr = 16; + constexpr int sr = 2; + return get_ukernel_config_kleidi< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, m_step, mr, + n_step, nr, kr, sr>(); } - KAI_GEN_UKERNEL( - torchao::kernels::cpu::aarch64::kleidi:: - kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_dotprod_1x4x32); - return config; + if constexpr (kernel_id == dotprod_1x4x32) { + constexpr int m_step = 1; + constexpr int mr = 1; + constexpr int n_step = 4; + constexpr int nr = 4; + constexpr int kr = 16; + constexpr int sr = 2; + return get_ukernel_config_kleidi< + matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, m_step, mr, + n_step, nr, kr, sr>(); + } + throw std::runtime_error("Unsupported kernel_id"); } #endif // TORCHAO_ENABLE_KLEIDI @@ -253,7 +278,6 @@ TEST(test_linear_8bit_act_xbit_weight, GroupSizeNotDivisibleBy16) { std::runtime_error); } -// begin /* Generated by generate_tests.py */ /* Do not modify */ @@ -340,6 +364,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn222xk32xg32) { /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m1xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -494,6 +552,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m41xn222xk32xg32) { /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m7xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/7, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m17xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/17, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x4x32_m23xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/23, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m41xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/41, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x4x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -610,6 +702,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn222xk32xg32) { /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m1xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -764,6 +890,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m41xn222xk32xg32) { /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m7xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/7, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m17xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/17, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_dotprod_1x8x32_m23xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/23, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m41xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/41, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_dotprod_1x8x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -878,6 +1038,39 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn222xk32xg32) { /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m1xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -1029,6 +1222,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m41xn222xk32xg32) { /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m7xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/7, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m17xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/17, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_4x8x32_m23xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/23, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m41xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/41, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_4x8x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -1144,6 +1371,39 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn222xk32xg32) { /*m=*/1, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m1xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/1, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m1xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< @@ -1295,6 +1555,40 @@ TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m41xn222xk32xg32) { /*m=*/41, /*n=*/222, /*k=*/32, /*group_size=*/32, &ukernel_config); } +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m7xn11xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/7, /*n=*/11, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m17xn13xk32xg32_bias) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, true /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/17, /*n=*/13, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, + Kleidi_i8mm_8x4x32_m23xn51xk32xg32_clamp) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + true /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/23, /*n=*/51, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + +TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m41xn111xk32xg32) { + UKernelConfig ukernel_config = get_ukernel_config_kleidi(); + test_linear_8bit_act_xbit_weight< + 4 /*weight_nbit*/, false /*has_weight_zeros*/, false /*has_bias*/, + false /*has_clamp*/, true /*has_kleidi*/>( + /*m=*/41, /*n=*/111, /*k=*/32, /*group_size=*/32, &ukernel_config); +} + TEST(test_linear_8bit_act_xbit_weight, Kleidi_i8mm_8x4x32_m19xn14xk64xg32) { UKernelConfig ukernel_config = get_ukernel_config_kleidi(); test_linear_8bit_act_xbit_weight< From 0293bcdd596fc28c61706cccbebb994956e865cc Mon Sep 17 00:00:00 2001 From: "Jane (Yuan) Xu" <31798555+janeyx99@users.noreply.github.com> Date: Thu, 20 Feb 2025 16:21:08 -0500 Subject: [PATCH 099/115] Remove duplicate, confusing conditional in setup.py (#1748) let `get_extensions` handle the logic instead --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 357e0e491f..ee3ebbf453 100644 --- a/setup.py +++ b/setup.py @@ -312,7 +312,7 @@ def get_extensions(): package_data={ "torchao.kernel.configs": ["*.pkl"], }, - ext_modules=get_extensions() if use_cpp != "0" else None, + ext_modules=get_extensions(), extras_require={"dev": read_requirements("dev-requirements.txt")}, description="Package for applying ao techniques to GPU models", long_description=open("README.md").read(), From 6bab4dbb26cd571e8d46d8169737b2c61484d254 Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Thu, 20 Feb 2025 14:08:41 -0800 Subject: [PATCH 100/115] SAM2: Use torch.export for VOS (#1708) --- .../sam2_amg_server/compile_export_utils.py | 1 - examples/sam2_amg_server/generate_data.py | 2 +- .../sam2_vos_example/compile_export_utils.py | 271 +++++++++++++++++ examples/sam2_vos_example/requirements.txt | 2 + examples/sam2_vos_example/video_profile.py | 283 +++++------------- torchao/_models/sam2/modeling/sam2_base.py | 14 +- 6 files changed, 361 insertions(+), 212 deletions(-) create mode 100644 examples/sam2_vos_example/compile_export_utils.py create mode 100644 examples/sam2_vos_example/requirements.txt diff --git a/examples/sam2_amg_server/compile_export_utils.py b/examples/sam2_amg_server/compile_export_utils.py index 5903f4905e..d1c6fc06fa 100644 --- a/examples/sam2_amg_server/compile_export_utils.py +++ b/examples/sam2_amg_server/compile_export_utils.py @@ -16,7 +16,6 @@ TASK_TYPES = ["amg", "sps", "mps"] -# NOTE: We have to declare a separate class, because torch.export demands it. # We build this explicitly for the sole purpose of exporting _predict_masks # We made sure _predict_masks is fullgraph=True compileable so it can be exported # We must be sure to export using example args that are big enough and past diff --git a/examples/sam2_amg_server/generate_data.py b/examples/sam2_amg_server/generate_data.py index 311a3825ec..50eeccb912 100644 --- a/examples/sam2_amg_server/generate_data.py +++ b/examples/sam2_amg_server/generate_data.py @@ -551,7 +551,7 @@ def main( sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle" ) if export_model != "": - if not Path(output_folder).is_dir(): + if not Path(export_model).is_dir(): raise ValueError(f"Expected {export_model} to be a directory.") print(f"Exporting model to {export_model}.") from compile_export_utils import export_model as export_model_fn diff --git a/examples/sam2_vos_example/compile_export_utils.py b/examples/sam2_vos_example/compile_export_utils.py new file mode 100644 index 0000000000..7d1b3eddf3 --- /dev/null +++ b/examples/sam2_vos_example/compile_export_utils.py @@ -0,0 +1,271 @@ +import time +from pathlib import Path +from typing import Optional + +import torch + +from torchao._models.sam2.sam2_video_predictor import SAM2VideoPredictor + +# Tools used to avoid compilation cold start and dynamo cache lookups +# We take the compiled model and export it using the largest +# inputs possible (to avoid recompilations). +# We track the largest size and fail if we size something larger +# We export every compile-able subregion after wrapping it into +# a class to make export happy. + +TASK_TYPES = ["amg", "sps", "mps"] + + +class SAM2VideoPredictor_forward_sam_heads(torch.nn.Module): + def __init__( + self, + predictor: Optional[SAM2VideoPredictor], + batch_size=1, + aoti_compiled_model=None, + furious=False, + ): + super().__init__() + self.predictor = predictor + self.batch_size = batch_size + self.aoti_compiled_model = aoti_compiled_model + self.furious = furious + + def forward( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + assert mask_inputs is None + assert multimask_output + if self.predictor is None: + assert self.aoti_compiled_model is not None + return self.aoti_compiled_model( + backbone_features=backbone_features, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + return self.predictor._forward_sam_heads( + backbone_features=backbone_features, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + +def aot_compile( + model_directory, + name, + fn, + sample_args, + sample_kwargs=None, + options=None, + overwrite=False, +): + path = Path(model_directory) / Path(f"{name}.pt2") + if path.exists() and not overwrite: + raise ValueError(f"{path} already exists and overwrite is {overwrite}") + print(f"Saving at {path=}") + if options is None: + options = { + "max_autotune": True, + "triton.cudagraphs": True, + } + + from torch.export import export_for_inference + + exported = export_for_inference(fn, sample_args, sample_kwargs) + output_path = torch._inductor.aoti_compile_and_package( + exported, + package_path=str(path), + inductor_configs=options, + ) + return output_path + + +def aot_load(path): + return torch._export.aot_load(path, "cuda") + + +class FunctionModel(torch.nn.Module): + def __init__(self, module, fn_name): + super().__init__() + self.module = module + self.fn_name = fn_name + + def forward(self, *args): + return getattr(self.module, self.fn_name)(*args) + + +def export_model( + predictor, + model_directory, + furious=False, + batch_size=1, + overwrite=False, +): + if furious: + set_furious(predictor) + + example_input = torch.empty(batch_size, 3, 1024, 1024) + # example_input = example_input.to(predictor._image_dtype) + example_input = example_input.to(torch.bfloat16) + # example_input = (example_input.to(predictor.device),) + example_input = (example_input.to("cuda:0"),) + aot_compile( + model_directory, + "sam2_image_encoder_trunk", + predictor.image_encoder.trunk, + example_input, + overwrite=overwrite, + ) + + example_input_args = () + example_input_kwargs = { + "backbone_features": torch.randn( + batch_size, 256, 64, 64, dtype=torch.float32, device="cuda" + ), + # "point_inputs": { + # "point_coords": torch.ones(batch_size, 1, 2, dtype=torch.float32, device="cuda"), + # "point_labels": torch.ones(batch_size, 1, dtype=torch.int32, device="cuda"), + # }, + "point_inputs": None, + "mask_inputs": None, + "high_res_features": [ + torch.randn( + batch_size, + 32, + 256, + 256, + dtype=torch.bfloat16, + device="cuda", + ), + torch.randn( + batch_size, + 64, + 128, + 128, + dtype=torch.bfloat16, + device="cuda", + ), + ], + "multimask_output": True, + } + sam2_video_forward_sam_heads = SAM2VideoPredictor_forward_sam_heads( + predictor, + batch_size=batch_size, + furious=False, + ) + aot_compile( + model_directory, + "sam2_video_forward_sam_heads", + sam2_video_forward_sam_heads, + example_input_args, + sample_kwargs=example_input_kwargs, + overwrite=overwrite, + ) + + return predictor + + +class LoadedModel(torch.nn.Module): + def __init__(self, aoti_compiled_model): + super().__init__() + self.aoti_compiled_model = aoti_compiled_model + + def forward(self, *args, **kwargs): + return self.aoti_compiled_model(*args, **kwargs) + + +class LoadedDecoder(torch.nn.Module): + def __init__(self, aoti_compiled_model, other): + super().__init__() + self.aoti_compiled_model = aoti_compiled_model + self.other = other + + def forward(self, *args): + return self.aoti_compiled_model(*args) + + def get_dense_pe(self, *args, **kwargs) -> torch.Tensor: + return self.other.get_dense_pe(*args, **kwargs) + + +def load_exported_model( + predictor, + model_directory, + furious=False, + batch_size=1, +): + if furious: + set_furious(predictor) + t0 = time.time() + path = Path(model_directory) / Path("sam2_image_encoder_trunk.pt2") + assert path.exists(), f"Expected {path} to exist" + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + pkg_m = LoadedModel(pkg) + predictor.image_encoder.trunk = pkg_m + + path = Path(model_directory) / Path("sam2_video_forward_sam_heads.pt2") + assert path.exists(), f"Expected {path} to exist" + print(f"Start load from {path}") + pkg = torch._inductor.aoti_load_package(str(path)) + pkg_m = SAM2VideoPredictor_forward_sam_heads( + None, + batch_size=batch_size, + aoti_compiled_model=pkg, + furious=furious, + ) + predictor._forward_sam_heads = pkg_m.forward + + print(f"End load image encoder and _forward_sam_heads. Took {time.time() - t0}s") + return predictor + + +def set_fast(predictor, loaded_exported_model=False): + if not loaded_exported_model: + predictor.image_encoder.trunk.forward = torch.compile( + predictor.image_encoder.trunk.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + if not loaded_exported_model: + predictor._forward_sam_heads = torch.compile( + predictor._forward_sam_heads, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + predictor.memory_attention = torch.compile( + predictor.memory_attention, + mode="max-autotune", + fullgraph=True, + dynamic=True, + ) + predictor.memory_encoder.forward = torch.compile( + predictor.memory_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + +def set_furious(mask_generator): + mask_generator.predictor.model.image_encoder = ( + mask_generator.predictor.model.image_encoder.to(torch.float16) + ) + # NOTE: Not baseline feature + mask_generator.predictor._image_dtype = torch.float16 + mask_generator.predictor._transforms_device = mask_generator.predictor.device + torch.set_float32_matmul_precision("high") + mask_generator.predictor.model.sam_mask_decoder = ( + mask_generator.predictor.model.sam_mask_decoder.to(torch.float16) + ) + # NOTE: Not baseline feature + mask_generator.predictor.model.sam_mask_decoder._src_dtype = torch.float16 diff --git a/examples/sam2_vos_example/requirements.txt b/examples/sam2_vos_example/requirements.txt new file mode 100644 index 0000000000..cacdf09b2c --- /dev/null +++ b/examples/sam2_vos_example/requirements.txt @@ -0,0 +1,2 @@ +requests +fire diff --git a/examples/sam2_vos_example/video_profile.py b/examples/sam2_vos_example/video_profile.py index 4a7b830d6b..8ee9151cc4 100644 --- a/examples/sam2_vos_example/video_profile.py +++ b/examples/sam2_vos_example/video_profile.py @@ -1,9 +1,9 @@ -import argparse import os import time from datetime import datetime from pathlib import Path +import fire import numpy as np import requests import torch @@ -43,11 +43,11 @@ def download_file(url, download_dir): response = requests.get(url, stream=True) response.raise_for_status() # Raise an error for bad responses # Write the file to the specified directory - print(f"Downloading '{file_name}' to '{download_dir}'") + timestamped_print(f"Downloading '{file_name}' to '{download_dir}'") with open(file_path, "wb") as file: for chunk in response.iter_content(chunk_size=8192): file.write(chunk) - print(f"Downloaded '{file_name}' to '{download_dir}'") + timestamped_print(f"Downloaded '{file_name}' to '{download_dir}'") def model_type_to_paths(checkpoint_path, model_type): @@ -57,7 +57,7 @@ def model_type_to_paths(checkpoint_path, model_type): ) sam2_checkpoint = Path(checkpoint_path) / Path(MODEL_TYPES_TO_MODEL[model_type]) if not sam2_checkpoint.exists(): - print( + timestamped_print( f"Can't find checkpoint {sam2_checkpoint} in folder {checkpoint_path}. Downloading." ) download_file(MODEL_TYPES_TO_URL[model_type], checkpoint_path) @@ -103,12 +103,12 @@ def reset(self): def print_all_timings(self, warmup: int = 5): if not self.elapsed_times: - print("No timings recorded.") + timestamped_print("No timings recorded.") return - print("Average timings for all sections:") + timestamped_print("Average timings for all sections:") for section_name in self.elapsed_times: average_time = self.get_average_time(section_name, warmup) - print(f"{section_name}, {average_time*1000.0:.6f}") + timestamped_print(f"{section_name}, {average_time*1000.0:.6f}") global_timer = CodeTimer() @@ -121,7 +121,7 @@ def max_memory_allocated(): 100 * (max_memory_allocated_bytes / total_memory) ) max_memory_allocated_bytes = max_memory_allocated_bytes >> 20 - print( + timestamped_print( f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%" ) @@ -150,12 +150,12 @@ def synthesize_video_data( vy = np.random.choice([-1, 1]) * speed # TODO: If these frames exist, they will not be deleted in subsequent runs with less frames. - print(f"Generate {n_frames} frames under path {out_dir}") + timestamped_print(f"Generate {n_frames} frames under path {out_dir}") if not synthesize_overwrite and len(os.listdir(out_dir)) > 0: raise ValueError( f"Expected folder {out_dir} to be empty unless --synthesize-overwrite is specified." ) - # Generate 100 frames + # Generate n_frames for i in range(n_frames): # Create a new image with a black background img = Image.new("RGB", (width, height), (0, 0, 0)) @@ -192,7 +192,7 @@ def profiler_runner(path, fn, *args, **kwargs): ) as prof: result = fn(*args, **kwargs) prof.export_chrome_trace(path) - print(f"Exported trace to {path}") + timestamped_print(f"Exported trace to {path}") return result @@ -220,26 +220,36 @@ def main_loop( return num_output_frames -def run_test( +def timestamped_print(*args, **kwargs): + # Get the current timestamp + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") + # Prepend the timestamp to the original print arguments + print(f"[{timestamp}]: ", *args, **kwargs) + + +def main( checkpoint_path: str, model_type: str, - profile: bool, - video_dir: str, - radius: int, - seed: int, - speed: int, - width: int, - height: int, - n_frames: int, - use_compile: bool, - frame_batch_size: int, - batch_size: int, - synthesize: bool, - synthesize_overwrite: bool, - store_output: str, - compare_output: str, - print_all_timings: bool, - use_baseline: bool, + video_dir="/tmp/segment-anything-2/synth_video", + profile=None, + radius=50, + seed=42, + speed=20, + width=1024, + height=1024, + n_frames=200, + use_compile=False, + batch_size=1, + frame_batch_size=1, + synthesize=False, + synthesize_overwrite=False, + store_output="", + compare_output="", + print_all_timings=False, + use_baseline=False, + export_model="", + load_exported_model="", + furious=False, ): np.random.seed(seed) start_x = np.random.randint(radius, width - radius) @@ -281,10 +291,17 @@ def run_test( # hydra_overrides_extra=hydra_overrides_extra, ) predictor._frame_batch_size = frame_batch_size + predictor.image_encoder.trunk = predictor.image_encoder.trunk.to(torch.bfloat16) + from torchao._models.sam2.modeling.sam.transformer import RoPEAttention + + rope_attention_modules = [ + module for module in predictor.modules() if isinstance(module, RoPEAttention) + ] + for r in rope_attention_modules: + r.freqs_cis = r.compute_cis(end_x=64, end_y=64, device=device) inference_states = [] for i in range(batch_size): - print("i: ", i) inference_state = predictor.init_state( video_path=f"{video_dir}_{i}", async_loading_frames=False ) @@ -301,77 +318,54 @@ def run_test( else: inference_state = predictor.batch_inference_states(inference_states) - if use_compile: - print("Using torch.compile") - predictor.image_encoder.trunk.forward = torch.compile( - predictor.image_encoder.trunk.forward, - # mode="max-autotune-no-cudagraphs", - mode="max-autotune", - fullgraph=True, - dynamic=False, + if export_model != "": + if not Path(export_model).is_dir(): + raise ValueError(f"Expected {export_model} to be a directory.") + timestamped_print(f"Exporting model to {export_model}.") + from compile_export_utils import export_model as export_model_fn + + export_model_fn( + predictor, + export_model, + furious=furious, + batch_size=1, + overwrite=False, ) - predictor.sam_prompt_encoder.forward = torch.compile( - predictor.sam_prompt_encoder.forward, - # mode="max-autotune-no-cudagraphs", - mode="max-autotune", - fullgraph=True, - dynamic=False, - ) + if load_exported_model != "": + from compile_export_utils import load_exported_model as load_exported_model_fn - predictor.sam_mask_decoder.transformer = torch.compile( - predictor.sam_mask_decoder.transformer, - mode="max-autotune", - # mode="max-autotune-no-cudagraphs", - fullgraph=True, - dynamic=False, + load_exported_model_fn( + predictor, load_exported_model, furious=furious, batch_size=1 ) - predictor._forward_sam_heads = torch.compile( - predictor._forward_sam_heads, - mode="max-autotune", - # mode="max-autotune-no-cudagraphs", - fullgraph=True, - dynamic=False, - ) - - predictor.memory_attention = torch.compile( - predictor.memory_attention, - # mode="max-autotune", - # mode="max-autotune-no-cudagraphs", - fullgraph=True, - dynamic=True, - ) + if use_compile: + from compile_export_utils import set_fast - predictor.memory_encoder.forward = torch.compile( - predictor.memory_encoder.forward, - mode="max-autotune", - # mode="max-autotune-no-cudagraphs", - fullgraph=True, - dynamic=False, - ) + set_fast(predictor, (load_exported_model != "")) - print("\nWarm-up round and gather outputs.") + timestamped_print("Warm-up round and gather outputs.") global_timer.reset() result = main_loop( predictor=predictor, inference_state=inference_state, accumulate_result=True ) if store_output: - print(f"Writing results to {store_output}") + timestamped_print(f"Writing results to {store_output}") torch.save(result, store_output) if compare_output: - print(f"Comparing to results from {compare_output}") + timestamped_print(f"Comparing to results from {compare_output}") ref_result = torch.load(compare_output) torch.testing.assert_close(result, ref_result) - print("Passed comparison!") + timestamped_print("Passed comparison!") if print_all_timings: global_timer.print_all_timings() global_timer.reset() - print("\nProfile round.") if profile is None: + timestamped_print("Practice round") main_loop(predictor=predictor, inference_state=inference_state) else: + timestamped_print(f"Saving profile under {profile}") profiler_runner( profile, main_loop, @@ -381,7 +375,7 @@ def run_test( if print_all_timings: global_timer.print_all_timings() - print("\nFinal timing and memory usage round.") + timestamped_print("Final timing and memory usage round.") torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() global_timer.reset() @@ -390,7 +384,7 @@ def run_test( predictor=predictor, inference_state=inference_state, count_result=True ) t = time.time() - t0 - print( + timestamped_print( f"main_loop took {t}s for {num_output_frames} frames at {num_output_frames / t}fps" ) max_memory_allocated() @@ -399,131 +393,4 @@ def run_test( if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "checkpoint_path", - type=str, - help="Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints", - ) - parser.add_argument( - "model_type", - type=str, - help=f"Choose one of {list(MODEL_TYPES_TO_MODEL.keys())}", - ) - parser.add_argument( - "--video_dir", - type=str, - default="/tmp/segment-anything-2/synth_video", - help="Directory to store the synthetic video", - ) - parser.add_argument( - "--profile", - type=str, - dest="profile", - help="If specified stores profile at given path.", - ) - parser.add_argument( - "--radius", - type=int, - default=50, - help="Radius of the circle for synthetic video", - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="Seed for initial position and velocity", - ) - parser.add_argument( - "--speed", type=int, default=20, help="Speed of the circle for synthetic video" - ) - parser.add_argument( - "--width", type=int, default=1024, help="Width of the synthetic video" - ) - parser.add_argument( - "--height", type=int, default=1024, help="Height of the synthetic video" - ) - parser.add_argument( - "--n_frames", - type=int, - default=200, - help="Number of frames in the synthetic video", - ) - parser.add_argument( - "--use-compile", - action="store_true", - dest="use_compile", - help="Use torch.compile to speed things up. First iteration will be much slower.", - ) - parser.add_argument( - "--batch-size", - type=int, - default=1, - help="batch_size", - ) - parser.add_argument( - "--frame-batch-size", - type=int, - default=1, - help="frame_batch_size", - ) - parser.add_argument( - "--synthesize", - action="store_true", - dest="synthesize", - help="Synthesize data for the benchmark.", - ) - parser.add_argument( - "--synthesize-overwrite", - action="store_true", - dest="synthesize_overwrite", - help="Overwrite data if it already exists when synthesizing.", - ) - parser.add_argument( - "--store-output", - type=str, - default="", - help="Pass a .pt file to store outputs in.", - ) - parser.add_argument( - "--compare-output", - type=str, - default="", - help="Pass a .pt file to load for comparison.", - ) - parser.add_argument( - "--print-all-timings", - action="store_true", - dest="print_all_timings", - help="Use torch.compile to speed things up. First iteration will be much slower.", - ) - parser.add_argument( - "--use-baseline", - action="store_true", - dest="use_baseline", - help="Use sam2 package instead of torchao._models.sam2", - ) - - args = parser.parse_args() - - run_test( - args.checkpoint_path, - args.model_type, - profile=args.profile, - video_dir=args.video_dir, - radius=args.radius, - seed=args.seed, - speed=args.speed, - width=args.width, - height=args.height, - n_frames=args.n_frames, - use_compile=args.use_compile, - frame_batch_size=args.frame_batch_size, - batch_size=args.batch_size, - synthesize=args.synthesize, - synthesize_overwrite=args.synthesize_overwrite, - store_output=args.store_output, - compare_output=args.compare_output, - print_all_timings=args.print_all_timings, - use_baseline=args.use_baseline, - ) + fire.Fire(main) diff --git a/torchao/_models/sam2/modeling/sam2_base.py b/torchao/_models/sam2/modeling/sam2_base.py index 01da983efc..4c2a24a0ef 100644 --- a/torchao/_models/sam2/modeling/sam2_base.py +++ b/torchao/_models/sam2/modeling/sam2_base.py @@ -670,6 +670,10 @@ def _prepare_memory_conditioned_features( memory = torch.cat(to_cat_memory, dim=0) memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + current_vision_feats = [c.clone() for c in current_vision_feats] + current_vision_pos_embeds = [c.clone() for c in current_vision_pos_embeds] + memory = memory.clone() + memory_pos_embed = memory_pos_embed.clone() pix_feat_with_mem = self.memory_attention( curr=current_vision_feats, curr_pos=current_vision_pos_embeds, @@ -677,6 +681,7 @@ def _prepare_memory_conditioned_features( memory_pos=memory_pos_embed, num_obj_ptr_tokens=num_obj_ptr_tokens, ) + pix_feat_with_mem = pix_feat_with_mem.clone() # reshape the output (HW)BC => BCHW pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem @@ -784,11 +789,16 @@ def _track_step( assert point_inputs is not None and mask_inputs is None mask_inputs = prev_sam_mask_logits multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + + assert mask_inputs is None + assert multimask_output + if point_inputs is not None: + point_inputs = {k: point_inputs[k].contiguous() for k in point_inputs} sam_outputs = self._forward_sam_heads( - backbone_features=pix_feat, + backbone_features=pix_feat.contiguous(), point_inputs=point_inputs, mask_inputs=mask_inputs, - high_res_features=high_res_features, + high_res_features=[h.contiguous() for h in high_res_features], multimask_output=multimask_output, ) From 1c76736189220c445d3f5921aa59be9319e464ac Mon Sep 17 00:00:00 2001 From: cpuhrsch Date: Thu, 20 Feb 2025 17:02:36 -0800 Subject: [PATCH 101/115] Fix ruff for torchao/float8/config.py (#1750) --- torchao/float8/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchao/float8/config.py b/torchao/float8/config.py index ab2d89a91f..fa03d55b11 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -148,7 +148,6 @@ class Float8GemmConfig: # Pre-made recipes for common configurations class Float8LinearRecipeName(enum.Enum): - # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel TENSORWISE = "tensorwise" @@ -385,7 +384,6 @@ def from_recipe_name( ) elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP: - # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1 cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) From dc0134e46bf44f5887d6e9e70b9a6a03e43fa30c Mon Sep 17 00:00:00 2001 From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Date: Thu, 20 Feb 2025 22:32:07 -0600 Subject: [PATCH 102/115] Add ciflow/rocm to bot-created tags (#1749) --- .github/pytorch-probot.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 2b63be96e1..583be7c620 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -2,3 +2,4 @@ mergebot: True ciflow_push_tags: - ciflow/benchmark - ciflow/tutorials +- ciflow/rocm From e0f7148cfcaa2f0b5f0286ae0d83b4f77afd8106 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:33:25 -0800 Subject: [PATCH 103/115] Update to cutlass 3.8 tag (#1754) stack-info: PR: https://github.com/pytorch/ao/pull/1754, branch: drisspg/stack/37 --- third_party/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cutlass b/third_party/cutlass index e9627ce55b..afa1772203 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit e9627ce55b42fd2599f58cd4396da9380954def0 +Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 From 878ec7a8026da5fb237413f5a007c1e256da4df4 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 21 Feb 2025 15:47:03 -0500 Subject: [PATCH 104/115] Add linear bias support for QAT (#1755) **Summary:** Add linear bias support for QAT, which previously resulted in the following unintuitive error message: ``` RuntimeError: Boolean value of Tensor with more than one value is ambiguous ``` Note that we don't fake quantize the bias still. We just support applying QAT on linear modules with bias. **Test Plan:** python test/quantization/test_qat.py -k test_qat_linear_bias --- test/quantization/test_qat.py | 34 ++++++++++++++++++++++++++++++ torchao/quantization/qat/linear.py | 14 ++++++------ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 9aeaa53664..4d685169a1 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -133,6 +133,21 @@ def forward(self, x): return x +class ModelWithLinearBias(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(512, 256, bias=True) + self.linear2 = torch.nn.Linear(256, 512, bias=True) + + def example_inputs(self): + return (torch.randn(1, 512),) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + class TestQAT(unittest.TestCase): SEED = 123 @@ -1366,6 +1381,25 @@ def test_fake_quantizer_repr(self): self.assertTrue("PerGroup" in fake_quantizer_repr) self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr) + @unittest.skipIf( + not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" + ) + def test_qat_linear_bias(self): + """ + Test that QAT supports linear bias. + """ + m = ModelWithLinearBias() + activation_config = FakeQuantizeConfig( + torch.int8, "per_token", is_symmetric=False + ) + weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32) + quantize_( + m, + intx_quantization_aware_training(activation_config, weight_config), + ) + example_inputs = m.example_inputs() + m(*example_inputs) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py index fafda68d58..716634fe9d 100644 --- a/torchao/quantization/qat/linear.py +++ b/torchao/quantization/qat/linear.py @@ -75,9 +75,6 @@ def __init__( *args, **kwargs, ) - if bias: - raise NotImplementedError("bias not supported yet") - # initialize activation fake quantizer if activation_config is not None: self.activation_fake_quantizer = FakeQuantizer(activation_config) @@ -103,17 +100,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: w = self.weight_fake_quantizer(self.weight) else: w = self.weight - return F.linear(x, w) + return F.linear(x, w, self.bias) def to_linear(self) -> torch.nn.Linear: new_linear = torch.nn.Linear( - self.in_features, self.out_features, self.bias, device=self.weight.device + self.in_features, + self.out_features, + self.bias is not None, + device=self.weight.device, ) # In distributed training, the model may be instantiated # on the meta device, in which case there is no need to # copy the weights, and doing so will result in an error if self.weight.device != torch.device("meta"): new_linear.weight = self.weight + new_linear.bias = self.bias return new_linear @classmethod @@ -126,7 +127,7 @@ def from_linear( new_linear = FakeQuantizedLinear( mod.in_features, mod.out_features, - mod.bias, + mod.bias is not None, activation_config=activation_config, weight_config=weight_config, device=mod.weight.device, @@ -136,6 +137,7 @@ def from_linear( # copy the weights, and doing so will result in an error if mod.weight.device != torch.device("meta"): new_linear.weight = mod.weight + new_linear.bias = mod.bias return new_linear From ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb Mon Sep 17 00:00:00 2001 From: Peter Yeh Date: Fri, 21 Feb 2025 15:27:51 -0800 Subject: [PATCH 105/115] [Reland] ROCm CI (Infra + Skips) (#1581) This PR to skip the unit test failures for ROCm + infra changes to enable ROCm CI. **NOTE:** This PR aims to enable the ROCm CI testing for torchao _only for pushes to main branch_. The ROCm tests should start showing up here once this PR is merged: https://hud.pytorch.org/hud/pytorch/ao/main/1?per_page=50&name_filter=regression Torchao PRs can also trigger the ROCm CI runs using the `ciflow/rocm` PR label (https://github.com/pytorch/ao/pull/1749). Enabling ROCm CI testing on *all* torchao PRs will be done in a follow-up PR. This pull request introduces the `skip_if_rocm` decorator across various test files to skip tests that are not yet supported on ROCm. The changes ensure that tests are conditionally skipped if ROCm is detected, improving the test suite's compatibility with different environments. # Key changes include: ### Cherry-pick ROCm CI infra changes from #999 ### Configure workflow to trigger ROCm CI only for pushes to main branch, OR on PRs with the `ciflow/rocm` label ### Introduction of `skip_if_rocm` decorator: * Added `skip_if_rocm` import in multiple test files to conditionally skip tests not supported on ROCm. (`test/dtypes/test_affine_quantized.py`, `test/dtypes/test_floatx.py`, `test/float8/test_base.py`, `test/hqq/test_hqq_affine.py`, `test/integration/test_integration.py`, `test/kernel/test_galore_downproj.py`, `test/prototype/test_awq.py`, `test/prototype/test_low_bit_optim.py`, `test/prototype/test_splitk.py`, `test/quantization/test_galore_quant.py`, `test/quantization/test_marlin_qqq.py`, `test/sparsity/test_marlin.py`, `test/test_ops.py`, `test/test_s8s4_linear_cutlass.py`, `torchao/utils.py`) [[1]](diffhunk://#diff-31b1ffcd78674b79cc65749176354ea4743683070120034709c1da7a3eac31f6R24) [[2]](diffhunk://#diff-0e811fa3416cd87d9a25b4fb680890098c69aa33ca4db4d347d4a10cc41e0eb3L30-R30) [[3]](diffhunk://#diff-05925b4469eb63ab854cc9891f088f570fa3822cdaeb4de109e0b1b9ab5038a7R21) [[4]](diffhunk://#diff-a9708dc28f15bb9cf665417e6c66601f9e8e2f1f672d1858603b74fa879a3357R13) [[5]](diffhunk://#diff-a977c33299f20a626cf650b2b6f0a49ef8fad7c97be21a5618e600b588b14b15R83) [[6]](diffhunk://#diff-4b0ddf8d1e85f4b4f1067f8d1d3e6b4d48785b3675c7202bf49bfbb1079d682fR14) [[7]](diffhunk://#diff-66249d5a8ed995b0a8e22c6354d6b270c5feeb982cb79a28f7c1b929700e89f4L8-R12) [[8]](diffhunk://#diff-244d33d1e8c30e765556011a4d3b76509f61433a346ba12ffc3115144e895aedR33) [[9]](diffhunk://#diff-2bcf3336ff64bfef786e6126813db46040b93628cab5faff3f0f5ed2cb077bf2L16-R24) [[10]](diffhunk://#diff-51ddab022797064be44ca38c87a56c6e87cd69444f4c6151a11b7f0141aef2b9R21) [[11]](diffhunk://#diff-133d8c7492ee2e7536328c8391545610750774e43d128d258380cb6787bb9e93L22-R22) [[12]](diffhunk://#diff-a58427e02fb5b05d26e03e8c2d216e5ae379d82084fd14bf77ea127b5505a43cL18-R18) [[13]](diffhunk://#diff-d183f2afc51d6a59bc70094e8f476d2468c45e415500f6eb60abad955e065156R22-R24) [[14]](diffhunk://#diff-85cc98d31eb8056e082ebdfbf2979aaa046ffc08bbacd4a65a31795b51998645R10-R12) [[15]](diffhunk://#diff-d2a11602a79e83305208472f1abe6a4106f02ce62a7f9524007181813863fcf6R10) ### Application of `skip_if_rocm` decorator: * Applied `@skip_if_rocm("ROCm development in progress")` to multiple test functions to skip them when running on ROCm. (`test/dtypes/test_affine_quantized.py`, `test/dtypes/test_floatx.py`, `test/float8/test_base.py`, `test/hqq/test_hqq_affine.py`, `test/integration/test_integration.py`, `test/kernel/test_galore_downproj.py`, `test/prototype/test_awq.py`, `test/prototype/test_low_bit_optim.py`, `test/prototype/test_splitk.py`, `test/quantization/test_galore_quant.py`, `test/quantization/test_marlin_qqq.py`, `test/sparsity/test_marlin.py`) [[1]](diffhunk://#diff-31b1ffcd78674b79cc65749176354ea4743683070120034709c1da7a3eac31f6R93) [[2]](diffhunk://#diff-31b1ffcd78674b79cc65749176354ea4743683070120034709c1da7a3eac31f6R173) [[3]](diffhunk://#diff-31b1ffcd78674b79cc65749176354ea4743683070120034709c1da7a3eac31f6R186) [[4]](diffhunk://#diff-0e811fa3416cd87d9a25b4fb680890098c69aa33ca4db4d347d4a10cc41e0eb3R111) [[5]](diffhunk://#diff-05925b4469eb63ab854cc9891f088f570fa3822cdaeb4de109e0b1b9ab5038a7R427) [[6]](diffhunk://#diff-a9708dc28f15bb9cf665417e6c66601f9e8e2f1f672d1858603b74fa879a3357R114) [[7]](diffhunk://#diff-a977c33299f20a626cf650b2b6f0a49ef8fad7c97be21a5618e600b588b14b15R571) [[8]](diffhunk://#diff-a977c33299f20a626cf650b2b6f0a49ef8fad7c97be21a5618e600b588b14b15R690) [[9]](diffhunk://#diff-a977c33299f20a626cf650b2b6f0a49ef8fad7c97be21a5618e600b588b14b15R710) [[10]](diffhunk://#diff-a977c33299f20a626cf650b2b6f0a49ef8fad7c97be21a5618e600b588b14b15R904) [[11]](diffhunk://#diff-a977c33299f20a626cf650b2b6f0a49ef8fad7c97be21a5618e600b588b14b15R924) [[12]](diffhunk://#diff-4b0ddf8d1e85f4b4f1067f8d1d3e6b4d48785b3675c7202bf49bfbb1079d682fR33) [[13]](diffhunk://#diff-66249d5a8ed995b0a8e22c6354d6b270c5feeb982cb79a28f7c1b929700e89f4R120) [[14]](diffhunk://#diff-244d33d1e8c30e765556011a4d3b76509f61433a346ba12ffc3115144e895aedR116) [[15]](diffhunk://#diff-2bcf3336ff64bfef786e6126813db46040b93628cab5faff3f0f5ed2cb077bf2L16-R24) [[16]](diffhunk://#diff-51ddab022797064be44ca38c87a56c6e87cd69444f4c6151a11b7f0141aef2b9R86) [[17]](diffhunk://#diff-133d8c7492ee2e7536328c8391545610750774e43d128d258380cb6787bb9e93R48) [[18]](diffhunk://#diff-133d8c7492ee2e7536328c8391545610750774e43d128d258380cb6787bb9e93R70) [[19]](diffhunk://#diff-a58427e02fb5b05d26e03e8c2d216e5ae379d82084fd14bf77ea127b5505a43cR40) [[20]](diffhunk://#diff-a58427e02fb5b05d26e03e8c2d216e5ae379d82084fd14bf77ea127b5505a43cL51-R58) ### Module-level skips for ROCm: * Added module-level skips for ROCm in specific test files to skip all tests within the module if ROCm is detected. (`test/test_ops.py`, `test/test_s8s4_linear_cutlass.py`) [[1]](diffhunk://#diff-d183f2afc51d6a59bc70094e8f476d2468c45e415500f6eb60abad955e065156R22-R24) [[2]](diffhunk://#diff-85cc98d31eb8056e082ebdfbf2979aaa046ffc08bbacd4a65a31795b51998645R10-R12) --- .github/workflows/regression_test_rocm.yml | 49 +++++++++++++++++++ test/dtypes/test_affine_quantized.py | 4 ++ .../test_affine_quantized_tensor_parallel.py | 4 ++ test/dtypes/test_floatx.py | 3 +- test/dtypes/test_nf4.py | 3 ++ test/dtypes/test_uint4.py | 4 +- test/float8/test_base.py | 2 + test/float8/test_float8_utils.py | 3 +- test/float8/test_fsdp2/test_fsdp2.py | 3 ++ test/hqq/test_hqq_affine.py | 2 + test/integration/test_integration.py | 8 +++ test/kernel/test_fused_kernels.py | 3 ++ test/kernel/test_galore_downproj.py | 2 + test/prototype/test_awq.py | 7 ++- test/prototype/test_low_bit_optim.py | 7 +++ test/prototype/test_smoothquant.py | 3 ++ test/prototype/test_splitk.py | 4 +- test/quantization/test_galore_quant.py | 2 + test/quantization/test_marlin_qqq.py | 5 +- test/quantization/test_quant_api.py | 2 + test/sparsity/test_marlin.py | 5 +- test/test_ops.py | 3 ++ torchao/dtypes/uintx/marlin_qqq_tensor.py | 4 +- torchao/dtypes/uintx/marlin_sparse_layout.py | 4 +- torchao/utils.py | 30 +++++++++++- 25 files changed, 153 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/regression_test_rocm.yml diff --git a/.github/workflows/regression_test_rocm.yml b/.github/workflows/regression_test_rocm.yml new file mode 100644 index 0000000000..9a9a6c0071 --- /dev/null +++ b/.github/workflows/regression_test_rocm.yml @@ -0,0 +1,49 @@ +name: Run Regression Tests on ROCm + +on: + push: + branches: + - main + tags: + - ciflow/rocm/* + +concurrency: + group: regression_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} + cancel-in-progress: true + +env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + +jobs: + test-nightly: + strategy: + fail-fast: false + matrix: + include: + - name: ROCM Nightly + runs-on: linux.rocm.gpu.torchao + torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/rocm6.3' + gpu-arch-type: "rocm" + gpu-arch-version: "6.3" + + permissions: + id-token: write + contents: read + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + timeout: 120 + no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }} + runner: ${{ matrix.runs-on }} + gpu-arch-type: ${{ matrix.gpu-arch-type }} + gpu-arch-version: ${{ matrix.gpu-arch-version }} + submodules: recursive + script: | + conda create -n venv python=3.9 -y + conda activate venv + python -m pip install --upgrade pip + pip install ${{ matrix.torch-spec }} + pip install -r dev-requirements.txt + pip install . + export CONDA=$(dirname $(dirname $(which conda))) + export LD_LIBRARY_PATH=$CONDA/lib/:$LD_LIBRARY_PATH + pytest test --verbose -s diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 112cab8684..67ce8df78f 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -25,6 +25,7 @@ TORCH_VERSION_AT_LEAST_2_6, is_fbcode, is_sm_at_least_89, + skip_if_rocm, ) is_cusparselt_available = ( @@ -104,6 +105,7 @@ def test_tensor_core_layout_transpose(self): "apply_quant", get_quantization_functions(is_cusparselt_available, True, "cuda", True), ) + @skip_if_rocm("ROCm enablement in progress") def test_weights_only(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") if isinstance(apply_quant, AOBaseConfig): @@ -196,6 +198,7 @@ def apply_uint6_weight_only_quant(linear): "apply_quant", get_quantization_functions(is_cusparselt_available, True) ) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") def test_print_quantized_module(self, apply_quant): linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") if isinstance(apply_quant, AOBaseConfig): @@ -213,6 +216,7 @@ class TestAffineQuantizedBasic(TestCase): @common_utils.parametrize("device", COMMON_DEVICES) @common_utils.parametrize("dtype", COMMON_DTYPES) + @skip_if_rocm("ROCm enablement in progress") def test_flatten_unflatten(self, device, dtype): if device == "cuda" and dtype == torch.bfloat16 and is_fbcode(): raise unittest.SkipTest("TODO: Failing for cuda + bfloat16 in fbcode") diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index 76b6b74a3d..b60f3251dc 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -1,5 +1,6 @@ import unittest +import pytest import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.testing._internal import common_utils @@ -27,6 +28,9 @@ except ModuleNotFoundError: has_gemlite = False +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) + class TestAffineQuantizedTensorParallel(DTensorTestBase): """Basic test case for tensor subclasses""" diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 8bb39b2cc8..f321d81b9e 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -27,7 +27,7 @@ fpx_weight_only, quantize_, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _Floatx_DTYPES = [(3, 2), (2, 2)] @@ -109,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits): @parametrize("bias", [False, True]) @parametrize("dtype", [torch.half, torch.bfloat16]) @unittest.skipIf(is_fbcode(), reason="broken in fbcode") + @skip_if_rocm("ROCm enablement in progress") def test_fpx_weight_only(self, ebits, mbits, bias, dtype): N, OC, IC = 4, 256, 64 device = "cuda" diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index caa1a6c7bd..a5190fb679 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -33,6 +33,7 @@ nf4_weight_only, to_nf4, ) +from torchao.utils import skip_if_rocm bnb_available = False @@ -111,6 +112,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype): @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype): # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47 @@ -133,6 +135,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype): @unittest.skipIf(not bnb_available, "Need bnb availble") @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_nf4_bnb_linear(self, dtype: torch.dtype): """ diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index e148d68abb..9d0c4e82df 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -28,7 +28,7 @@ from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm def _apply_weight_only_uint4_quant(model): @@ -92,6 +92,7 @@ def test_basic_tensor_ops(self): # only test locally # print("x:", x[0]) + @skip_if_rocm("ROCm enablement in progress") def test_gpu_quant(self): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: x = torch.randn(*x_shape) @@ -104,6 +105,7 @@ def test_gpu_quant(self): # make sure it runs opt(x) + @skip_if_rocm("ROCm enablement in progress") def test_pt2e_quant(self): from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( QuantizationConfig, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 156c8abe87..350f0fb175 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -18,6 +18,7 @@ TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89, is_sm_at_least_90, + skip_if_rocm, ) if not TORCH_VERSION_AT_LEAST_2_5: @@ -426,6 +427,7 @@ def test_linear_from_config_params( @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize("linear_bias", [True, False]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skip_if_rocm("ROCm enablement in progress") def test_linear_from_recipe( self, recipe_name, diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index ca9f21dde1..218d3b8c1f 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -4,7 +4,7 @@ import torch from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) @@ -30,6 +30,7 @@ # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]), ], ) +@skip_if_rocm("ROCm enablement in progress") def test_round_scale_down_to_power_of_2_valid_inputs( test_case: dict, ): diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index fbe5c9b508..0beb012406 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -43,6 +43,9 @@ if not is_sm_at_least_89(): pytest.skip("Unsupported CUDA device capability version", allow_module_level=True) +if torch.version.hip is not None: + pytest.skip("ROCm enablement in progress", allow_module_level=True) + class TestFloat8Common: def broadcast_module(self, module: nn.Module) -> None: diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index d18ff59f99..4ffe22cda8 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -11,6 +11,7 @@ ) from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, + skip_if_rocm, ) cuda_available = torch.cuda.is_available() @@ -109,6 +110,7 @@ def test_hqq_plain_5bit(self): ref_dot_product_error=0.000704, ) + @skip_if_rocm("ROCm enablement in progress") def test_hqq_plain_4bit(self): self._test_hqq( dtype=torch.uint4, diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 56bcaf17df..8327580748 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -85,6 +85,7 @@ benchmark_model, is_fbcode, is_sm_at_least_90, + skip_if_rocm, unwrap_tensor_subclass, ) @@ -95,6 +96,7 @@ except ModuleNotFoundError: has_gemlite = False + logger = logging.getLogger("INFO") torch.manual_seed(0) @@ -582,6 +584,7 @@ def test_per_token_linear_cpu(self): self._test_per_token_linear_impl("cpu", dtype) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") def test_per_token_linear_cuda(self): for dtype in (torch.float32, torch.float16, torch.bfloat16): self._test_per_token_linear_impl("cuda", dtype) @@ -700,6 +703,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -719,6 +723,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -912,6 +917,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") @@ -931,6 +937,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): if dtype != torch.bfloat16: self.skipTest(f"Fails for {dtype}") @@ -1102,6 +1109,7 @@ def test_gemlite_layout(self, device, dtype): @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") + @skip_if_rocm("ROCm enablement in progress") def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype): if device == "cpu": self.skipTest(f"Temporarily skipping for {device}") diff --git a/test/kernel/test_fused_kernels.py b/test/kernel/test_fused_kernels.py index c5bf6e17f0..cad1f001ff 100644 --- a/test/kernel/test_fused_kernels.py +++ b/test/kernel/test_fused_kernels.py @@ -11,6 +11,8 @@ import torch from galore_test_utils import get_kernel, make_copy, make_data +from torchao.utils import skip_if_rocm + torch.manual_seed(0) MAX_DIFF_no_tf32 = 1e-5 MAX_DIFF_tf32 = 1e-3 @@ -104,6 +106,7 @@ def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS) +@skip_if_rocm("ROCm enablement in progress") def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32): torch.backends.cuda.matmul.allow_tf32 = allow_tf32 diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py index bab65fc2fb..2388f0be63 100644 --- a/test/kernel/test_galore_downproj.py +++ b/test/kernel/test_galore_downproj.py @@ -11,6 +11,7 @@ from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk from torchao.prototype.galore.kernels.matmul import triton_mm_launcher +from torchao.utils import skip_if_rocm torch.manual_seed(0) @@ -29,6 +30,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS) +@skip_if_rocm("ROCm enablement in progress") def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype): torch.backends.cuda.matmul.allow_tf32 = allow_tf32 MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 1b91983bc0..409518ae9a 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -5,7 +5,11 @@ import torch from torchao.quantization import quantize_ -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_5, + skip_if_rocm, +) if TORCH_VERSION_AT_LEAST_2_3: from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_ @@ -113,6 +117,7 @@ def test_awq_loading(device, qdtype): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@skip_if_rocm("ROCm enablement in progress") def test_save_weights_only(): dataset_size = 100 l1, l2, l3 = 512, 256, 128 diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index d7d6fe7dc8..5ce3d08b81 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -30,6 +30,7 @@ TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, get_available_devices, + skip_if_rocm, ) try: @@ -42,6 +43,8 @@ except ImportError: lpmm = None +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) _DEVICES = get_available_devices() @@ -112,6 +115,7 @@ class TestOptim(TestCase): ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) + @skip_if_rocm("ROCm enablement in progress") def test_optim_smoke(self, optim_name, dtype, device): if optim_name.endswith("Fp8") and device == "cuda": if not TORCH_VERSION_AT_LEAST_2_4: @@ -185,6 +189,7 @@ def test_subclass_slice(self, subclass, shape, device): not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA", ) + @skip_if_rocm("ROCm enablement in progress") @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) def test_optim_8bit_correctness(self, optim_name): device = "cuda" @@ -413,6 +418,7 @@ def world_size(self) -> int: not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + @skip_if_rocm("ROCm enablement in progress") def test_fsdp2(self): optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit] if torch.cuda.get_device_capability() >= (8, 9): @@ -523,6 +529,7 @@ def _test_fsdp2(self, optim_cls): not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required." ) @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE) + @skip_if_rocm("ROCm enablement in progress") def test_uneven_shard(self): in_dim = 512 out_dim = _FSDP_WORLD_SIZE * 16 + 1 diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index 02b41e8e32..d90990143c 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -20,6 +20,9 @@ TORCH_VERSION_AT_LEAST_2_5, ) +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) + class ToyLinearModel(torch.nn.Module): def __init__(self, m=512, n=256, k=128): diff --git a/test/prototype/test_splitk.py b/test/prototype/test_splitk.py index 48793ba907..04fdd7cff2 100644 --- a/test/prototype/test_splitk.py +++ b/test/prototype/test_splitk.py @@ -13,13 +13,15 @@ except ImportError: triton_available = False -from torchao.utils import skip_if_compute_capability_less_than + +from torchao.utils import skip_if_compute_capability_less_than, skip_if_rocm @unittest.skipIf(not triton_available, "Triton is required but not available") @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") class TestFP8Gemm(TestCase): @skip_if_compute_capability_less_than(9.0) + @skip_if_rocm("ROCm enablement in progress") def test_gemm_split_k(self): dtype = torch.float16 qdtype = torch.float8_e4m3fn diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 3eb9b0a2c5..277bf6a49f 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -18,6 +18,7 @@ triton_dequant_blockwise, triton_quantize_blockwise, ) +from torchao.utils import skip_if_rocm SEED = 0 torch.manual_seed(SEED) @@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize): "dim1,dim2,dtype,signed,blocksize", TEST_CONFIGS, ) +@skip_if_rocm("ROCm enablement in progress") def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize): g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01 diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index 1fd60acb52..590c52bbde 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -18,9 +18,10 @@ MappingType, choose_qparams_and_quantize_affine_qqq, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm +@skip_if_rocm("ROCm enablement in progress") class TestMarlinQQQ(TestCase): def setUp(self): super().setUp() @@ -40,6 +41,7 @@ def setUp(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_marlin_qqq(self): output_ref = self.model(self.input) for group_size in [-1, 128]: @@ -61,6 +63,7 @@ def test_marlin_qqq(self): @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm development in progress") def test_marlin_qqq_compile(self): model_copy = copy.deepcopy(self.model) model_copy.forward = torch.compile(model_copy.forward, fullgraph=True) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index a53f47ac14..4e903f0a4b 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -55,6 +55,7 @@ TORCH_VERSION_AT_LEAST_2_6, is_sm_at_least_89, is_sm_at_least_90, + skip_if_rocm, unwrap_tensor_subclass, ) @@ -819,6 +820,7 @@ def test_int4wo_cpu(self, dtype, x_dim): uintx_weight_only(dtype=torch.uint4), ], ) + @skip_if_rocm("ROCm enablement in progress") def test_workflow_e2e_numerics(self, config): """ Simple test of e2e int4_weight_only workflow, comparing numerics diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index 4da7304a24..c8bdee5e2f 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -15,7 +15,7 @@ ) from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 from torchao.sparsity.sparse_api import apply_fake_sparsity -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm class SparseMarlin24(TestCase): @@ -37,6 +37,7 @@ def setUp(self): ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_eager(self): apply_fake_sparsity(self.model) model_copy = copy.deepcopy(self.model) @@ -48,13 +49,13 @@ def test_quant_sparse_marlin_layout_eager(self): # Sparse + quantized quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout())) sparse_result = self.model(self.input) - assert torch.allclose( dense_result, sparse_result, atol=3e-1 ), "Results are not close" @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @skip_if_rocm("ROCm enablement in progress") def test_quant_sparse_marlin_layout_compile(self): apply_fake_sparsity(self.model) model_copy = copy.deepcopy(self.model) diff --git a/test/test_ops.py b/test/test_ops.py index b3b160e85f..076ab9ab16 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -20,6 +20,9 @@ from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff +if torch.version.hip is not None: + pytest.skip("Skipping the test in ROCm", allow_module_level=True) + try: import torchao.ops except RuntimeError: diff --git a/torchao/dtypes/uintx/marlin_qqq_tensor.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py index 95175caacf..abf09cd2f9 100644 --- a/torchao/dtypes/uintx/marlin_qqq_tensor.py +++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py @@ -183,7 +183,7 @@ def __tensor_unflatten__( def get_plain(self): from torchao.quantization.marlin_qqq import ( unpack_from_marlin_qqq, - ) # avoid circular import + ) int_data_expanded, s_group_expanded, s_channel_expanded = ( unpack_from_marlin_qqq( @@ -211,7 +211,7 @@ def from_plain( from torchao.quantization.marlin_qqq import ( const, pack_to_marlin_qqq, - ) # avoid circular import + ) assert isinstance(_layout, MarlinQQQLayout) diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index 22763eb0c2..01d4562b7f 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -206,7 +206,7 @@ def __tensor_unflatten__( def get_plain(self): from torchao.sparsity.marlin import ( unpack_from_marlin_24, - ) # avoid circular import + ) int_data_expanded, scales_expanded = unpack_from_marlin_24( self.int_data, @@ -231,7 +231,7 @@ def from_plain( from torchao.sparsity.marlin import ( const, pack_to_marlin_24, - ) # avoid circular import + ) assert isinstance(_layout, MarlinSparseLayout) diff --git a/torchao/utils.py b/torchao/utils.py index 13b59c2e81..dfc18b2265 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -7,6 +7,7 @@ from math import gcd from typing import Any, Callable, Tuple +import pytest import torch import torch.nn.utils.parametrize as parametrize @@ -161,6 +162,33 @@ def wrapper(*args, **kwargs): return decorator +def skip_if_rocm(message=None): + """Decorator to skip tests on ROCm platform with custom message. + + Args: + message (str, optional): Additional information about why the test is skipped. + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if torch.version.hip is not None: + skip_message = "Skipping the test in ROCm" + if message: + skip_message += f": {message}" + pytest.skip(skip_message) + return func(*args, **kwargs) + + return wrapper + + # Handle both @skip_if_rocm and @skip_if_rocm() syntax + if callable(message): + func = message + message = None + return decorator(func) + return decorator + + def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor: return torch.mean(torch.abs(output - output_ref)) / torch.mean( torch.abs(output_ref) @@ -626,7 +654,7 @@ def _torch_version_at_least(min_version): def is_MI300(): if torch.cuda.is_available() and torch.version.hip: mxArchName = ["gfx940", "gfx941", "gfx942"] - archName = torch.cuda.get_device_properties().gcnArchName + archName = torch.cuda.get_device_properties(0).gcnArchName for arch in mxArchName: if arch in archName: return True From c72ebc65225cc4323fa48c8ad28ac1e4c5283a1e Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 21 Feb 2025 19:06:45 -0800 Subject: [PATCH 106/115] move decorators to testing/utils.py (#1761) * move decorators to testing/utils.py * add import * fix import * fix ruff formatting error * ruff fixes * ruff format * compute_capability test * update * update rest of tests * fix ruff --- test/dtypes/test_affine_quantized.py | 2 +- test/dtypes/test_floatx.py | 3 +- test/dtypes/test_nf4.py | 2 +- test/dtypes/test_uint4.py | 3 +- test/float8/test_base.py | 2 +- test/float8/test_float8_utils.py | 3 +- test/hqq/test_hqq_affine.py | 2 +- test/integration/test_integration.py | 2 +- test/kernel/test_fused_kernels.py | 2 +- test/kernel/test_galore_downproj.py | 2 +- test/prototype/test_awq.py | 2 +- test/prototype/test_low_bit_optim.py | 2 +- test/prototype/test_splitk.py | 2 +- test/quantization/test_galore_quant.py | 2 +- test/quantization/test_marlin_qqq.py | 3 +- test/quantization/test_quant_api.py | 2 +- test/sparsity/test_marlin.py | 3 +- torchao/testing/utils.py | 46 +++++++++++++++++++++++++- torchao/utils.py | 45 ------------------------- 19 files changed, 67 insertions(+), 63 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index 67ce8df78f..6b3a447070 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -20,12 +20,12 @@ quantize_, ) from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, is_fbcode, is_sm_at_least_89, - skip_if_rocm, ) is_cusparselt_available = ( diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index f321d81b9e..0953e33b0f 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -27,7 +27,8 @@ fpx_weight_only, quantize_, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm +from torchao.testing.utils import skip_if_rocm +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) _Floatx_DTYPES = [(3, 2), (2, 2)] diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index a5190fb679..4ed90d06ca 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -33,7 +33,7 @@ nf4_weight_only, to_nf4, ) -from torchao.utils import skip_if_rocm +from torchao.testing.utils import skip_if_rocm bnb_available = False diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py index 9d0c4e82df..cf4077a78c 100644 --- a/test/dtypes/test_uint4.py +++ b/test/dtypes/test_uint4.py @@ -28,7 +28,8 @@ from torchao.quantization.quant_api import ( _replace_with_custom_fn_if_matches_filter, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm +from torchao.testing.utils import skip_if_rocm +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 def _apply_weight_only_uint4_quant(model): diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 350f0fb175..818b413a77 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -14,11 +14,11 @@ import torch import torch.nn as nn +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89, is_sm_at_least_90, - skip_if_rocm, ) if not TORCH_VERSION_AT_LEAST_2_5: diff --git a/test/float8/test_float8_utils.py b/test/float8/test_float8_utils.py index 218d3b8c1f..1a6a888246 100644 --- a/test/float8/test_float8_utils.py +++ b/test/float8/test_float8_utils.py @@ -4,7 +4,8 @@ import torch from torchao.float8.float8_utils import _round_scale_down_to_power_of_2 -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm +from torchao.testing.utils import skip_if_rocm +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if not TORCH_VERSION_AT_LEAST_2_5: pytest.skip("Unsupported PyTorch version", allow_module_level=True) diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index 4ffe22cda8..7bbd52db09 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -9,9 +9,9 @@ quantize_, uintx_weight_only, ) +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, - skip_if_rocm, ) cuda_available = torch.cuda.is_available() diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 8327580748..7fd96e4d97 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -76,6 +76,7 @@ from torchao.quantization.utils import ( compute_error as SQNR, ) +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -85,7 +86,6 @@ benchmark_model, is_fbcode, is_sm_at_least_90, - skip_if_rocm, unwrap_tensor_subclass, ) diff --git a/test/kernel/test_fused_kernels.py b/test/kernel/test_fused_kernels.py index cad1f001ff..9c5bc19aaf 100644 --- a/test/kernel/test_fused_kernels.py +++ b/test/kernel/test_fused_kernels.py @@ -11,7 +11,7 @@ import torch from galore_test_utils import get_kernel, make_copy, make_data -from torchao.utils import skip_if_rocm +from torchao.testing.utils import skip_if_rocm torch.manual_seed(0) MAX_DIFF_no_tf32 = 1e-5 diff --git a/test/kernel/test_galore_downproj.py b/test/kernel/test_galore_downproj.py index 2388f0be63..fc8b784a9f 100644 --- a/test/kernel/test_galore_downproj.py +++ b/test/kernel/test_galore_downproj.py @@ -11,7 +11,7 @@ from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk from torchao.prototype.galore.kernels.matmul import triton_mm_launcher -from torchao.utils import skip_if_rocm +from torchao.testing.utils import skip_if_rocm torch.manual_seed(0) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index 409518ae9a..1bfdf57aca 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -5,10 +5,10 @@ import torch from torchao.quantization import quantize_ +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, - skip_if_rocm, ) if TORCH_VERSION_AT_LEAST_2_3: diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 5ce3d08b81..453210abda 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -26,11 +26,11 @@ from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8 +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, get_available_devices, - skip_if_rocm, ) try: diff --git a/test/prototype/test_splitk.py b/test/prototype/test_splitk.py index 04fdd7cff2..37aeac1334 100644 --- a/test/prototype/test_splitk.py +++ b/test/prototype/test_splitk.py @@ -14,7 +14,7 @@ triton_available = False -from torchao.utils import skip_if_compute_capability_less_than, skip_if_rocm +from torchao.testing.utils import skip_if_compute_capability_less_than, skip_if_rocm @unittest.skipIf(not triton_available, "Triton is required but not available") diff --git a/test/quantization/test_galore_quant.py b/test/quantization/test_galore_quant.py index 277bf6a49f..6b26b948f5 100644 --- a/test/quantization/test_galore_quant.py +++ b/test/quantization/test_galore_quant.py @@ -18,7 +18,7 @@ triton_dequant_blockwise, triton_quantize_blockwise, ) -from torchao.utils import skip_if_rocm +from torchao.testing.utils import skip_if_rocm SEED = 0 torch.manual_seed(SEED) diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py index 590c52bbde..f8581b1307 100644 --- a/test/quantization/test_marlin_qqq.py +++ b/test/quantization/test_marlin_qqq.py @@ -18,7 +18,8 @@ MappingType, choose_qparams_and_quantize_affine_qqq, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm +from torchao.testing.utils import skip_if_rocm +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @skip_if_rocm("ROCm enablement in progress") diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 4e903f0a4b..4af429940f 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -48,6 +48,7 @@ Int8WeightOnlyQuantizedLinearWeight, ) from torchao.quantization.utils import compute_error +from torchao.testing.utils import skip_if_rocm from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, @@ -55,7 +56,6 @@ TORCH_VERSION_AT_LEAST_2_6, is_sm_at_least_89, is_sm_at_least_90, - skip_if_rocm, unwrap_tensor_subclass, ) diff --git a/test/sparsity/test_marlin.py b/test/sparsity/test_marlin.py index c8bdee5e2f..dc4489f05e 100644 --- a/test/sparsity/test_marlin.py +++ b/test/sparsity/test_marlin.py @@ -15,7 +15,8 @@ ) from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24 from torchao.sparsity.sparse_api import apply_fake_sparsity -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm +from torchao.testing.utils import skip_if_rocm +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 class SparseMarlin24(TestCase): diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index d88241783f..02d151cdb4 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -14,7 +14,7 @@ from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx from torchao.quantization import int8_weight_only, quantize_ from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, get_compute_capability """ How to use: @@ -41,6 +41,50 @@ class MyTestCase(TorchAOBasicTestCase): """ +def skip_if_compute_capability_less_than(min_capability): + import unittest + + def decorator(test_func): + def wrapper(*args, **kwargs): + if get_compute_capability() < min_capability: + raise unittest.SkipTest( + f"Compute capability is less than {min_capability}" + ) + return test_func(*args, **kwargs) + + return wrapper + + return decorator + + +def skip_if_rocm(message=None): + """Decorator to skip tests on ROCm platform with custom message. + + Args: + message (str, optional): Additional information about why the test is skipped. + """ + import pytest + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if torch.version.hip is not None: + skip_message = "Skipping the test in ROCm" + if message: + skip_message += f": {message}" + pytest.skip(skip_message) + return func(*args, **kwargs) + + return wrapper + + # Handle both @skip_if_rocm and @skip_if_rocm() syntax + if callable(message): + func = message + message = None + return decorator(func) + return decorator + + # copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389 def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902 for name, value in my_cls.__dict__.items(): diff --git a/torchao/utils.py b/torchao/utils.py index dfc18b2265..2a67f8a9c9 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -7,7 +7,6 @@ from math import gcd from typing import Any, Callable, Tuple -import pytest import torch import torch.nn.utils.parametrize as parametrize @@ -16,7 +15,6 @@ "profiler_runner", "get_available_devices", "get_compute_capability", - "skip_if_compute_capability_less_than", "benchmark_torch_function_in_microseconds", "find_multiple", "_register_custom_op", @@ -146,49 +144,6 @@ def get_compute_capability(): return 0.0 -def skip_if_compute_capability_less_than(min_capability): - import unittest - - def decorator(test_func): - def wrapper(*args, **kwargs): - if get_compute_capability() < min_capability: - raise unittest.SkipTest( - f"Compute capability is less than {min_capability}" - ) - return test_func(*args, **kwargs) - - return wrapper - - return decorator - - -def skip_if_rocm(message=None): - """Decorator to skip tests on ROCm platform with custom message. - - Args: - message (str, optional): Additional information about why the test is skipped. - """ - - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if torch.version.hip is not None: - skip_message = "Skipping the test in ROCm" - if message: - skip_message += f": {message}" - pytest.skip(skip_message) - return func(*args, **kwargs) - - return wrapper - - # Handle both @skip_if_rocm and @skip_if_rocm() syntax - if callable(message): - func = message - message = None - return decorator(func) - return decorator - - def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor: return torch.mean(torch.abs(output - output_ref)) / torch.mean( torch.abs(output_ref) From 25ddb779c00a70c17f40253bee8901afa650b1fd Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Fri, 21 Feb 2025 19:12:11 -0800 Subject: [PATCH 107/115] Allow for scales to be in new e8m0 dtype (#1742) stack-info: PR: https://github.com/pytorch/ao/pull/1742, branch: drisspg/stack/36 --- torchao/ops.py | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/torchao/ops.py b/torchao/ops.py index bba2a054fc..a3aee761b9 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1,3 +1,5 @@ +import functools + import torch from torch import Tensor @@ -606,6 +608,27 @@ def _( return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) +@functools.lru_cache() +def _get_dtypes(): + """TODO: when e8m0 is hardened and major release lets remove uint8 support""" + if hasattr(torch, "float8_e8m0fnu"): + return (torch.uint8, torch.float8_e8m0fnu) + return (torch.uint8,) + + +def _check_scale_dtypes(A_scale, B_scale): + allowed_dtypes = _get_dtypes() + + torch._check( + A_scale.dtype in allowed_dtypes, + lambda: f"A_scale tensor must be uint8 or float8_e8m0fnu, got {A_scale.dtype}", + ) + torch._check( + B_scale.dtype in allowed_dtypes, + lambda: f"B_scale tensor must be uint8 or float8_e8m0fnu, got {B_scale.dtype}", + ) + + def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): """Defines a matmul between two fp8 tensors w/ MX scales in E8MO and returns a bf16 tensor. @@ -625,25 +648,7 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): MXN bf16 Tensor """ - torch._check( - A.dtype == torch.float8_e4m3fn, - lambda: f"Input tensor A must be float8_e4m3fn, got {A.dtype}", - ) - torch._check( - B.dtype == torch.float8_e4m3fn, - lambda: f"Input tensor B must be float8_e4m3fn, got {B.dtype}", - ) - - # TODO - Once e8m0 dtype is added to core udpate - # Check scale tensors are uint8 - torch._check( - A_scale.dtype == torch.uint8, - lambda: f"A_scale tensor must be uint8, got {A_scale.dtype}", - ) - torch._check( - B_scale.dtype == torch.uint8, - lambda: f"B_scale tensor must be uint8, got {B_scale.dtype}", - ) + _check_scale_dtypes(A_scale, B_scale) return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale) @@ -674,6 +679,7 @@ def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor): MXN bf16 Tensor """ + _check_scale_dtypes(A_scale, B_scale) return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale) From d370196369e1b1b6424cabaff6627d242dff2268 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Sat, 22 Feb 2025 06:36:17 -0800 Subject: [PATCH 108/115] delete delayed scaling from torchao.float8 (#1753) Update [ghstack-poisoned] --- benchmarks/float8/bench_linear_float8.py | 53 +-- benchmarks/float8/bench_multi_gpu.py | 180 --------- benchmarks/float8/float8_roofline.py | 47 --- benchmarks/float8/profile_linear_float8.py | 44 +-- test/float8/test_base.py | 105 +---- test/float8/test_compile.py | 176 +-------- test/float8/test_fsdp.py | 27 +- test/float8/test_fsdp.sh | 13 +- test/float8/test_fsdp2/test_fsdp2.py | 35 +- test/float8/test_fsdp_compile.py | 8 - test/float8/test_numerics_integration.py | 29 +- torchao/float8/README.md | 67 ---- torchao/float8/__init__.py | 12 - torchao/float8/config.py | 62 +-- torchao/float8/float8_linear.py | 6 +- torchao/float8/float8_linear_utils.py | 234 +---------- torchao/float8/float8_scaling_utils.py | 192 --------- torchao/float8/float8_tensor_parallel.py | 3 +- torchao/float8/float8_utils.py | 53 +-- torchao/float8/fsdp_utils.py | 336 ---------------- torchao/float8/inductor_utils.py | 126 ------ torchao/float8/roofline_utils.py | 113 ++---- torchao/float8/stateful_float8_linear.py | 439 --------------------- torchao/testing/float8/fsdp2_utils.py | 8 - torchao/testing/float8/test_utils.py | 21 - 25 files changed, 93 insertions(+), 2296 deletions(-) delete mode 100644 benchmarks/float8/bench_multi_gpu.py delete mode 100644 torchao/float8/inductor_utils.py delete mode 100644 torchao/float8/stateful_float8_linear.py diff --git a/benchmarks/float8/bench_linear_float8.py b/benchmarks/float8/bench_linear_float8.py index d160d7241d..a7b1e17934 100644 --- a/benchmarks/float8/bench_linear_float8.py +++ b/benchmarks/float8/bench_linear_float8.py @@ -23,10 +23,6 @@ ScalingType, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import ( - linear_requires_sync, - sync_float8_amax_and_scale_history, -) from torchao.float8.float8_tensor import ScaledMMConfig # estimating TOPs for matmuls in fp32, fp16, fp8 @@ -122,39 +118,18 @@ def main( scaling_type_grad_output = ScalingType(scaling_type_grad_output) scaling_granularity = ScalingGranularity(scaling_granularity) - if scaling_type_input is ScalingType.STATIC: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_input = CastConfig( - scaling_type=scaling_type_input, - scaling_granularity=scaling_granularity, - ) - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - scaling_granularity=scaling_granularity, - ) - if scaling_type_grad_output is ScalingType.STATIC: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - static_scale=torch.tensor([1.0], device="cuda"), - scaling_granularity=scaling_granularity, - ) - else: - cast_config_grad_output = CastConfig( - scaling_type=scaling_type_grad_output, - scaling_granularity=scaling_granularity, - ) + cast_config_input = CastConfig( + scaling_type=scaling_type_input, + scaling_granularity=scaling_granularity, + ) + cast_config_weight = CastConfig( + scaling_type=scaling_type_weight, + scaling_granularity=scaling_granularity, + ) + cast_config_grad_output = CastConfig( + scaling_type=scaling_type_grad_output, + scaling_granularity=scaling_granularity, + ) config = Float8LinearConfig( cast_config_input=cast_config_input, @@ -185,7 +160,7 @@ def main( copy.deepcopy(linear_ref), config=config, ) - scaling_repr = f"{linear_float8.scaling_type_repr()},{linear_float8.scaling_granularity_repr()}" + scaling_repr = linear_float8.extra_repr() if fast_accum: linear_float8.forward_config = ScaledMMConfig(False, True, False) @@ -196,8 +171,6 @@ def main( ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() def float8_forw_backward(): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(linear_float8) linear_float8(input_tensor).sum().backward() def n_times(n, fn, *args, **kwargs): diff --git a/benchmarks/float8/bench_multi_gpu.py b/benchmarks/float8/bench_multi_gpu.py deleted file mode 100644 index 34a690edbe..0000000000 --- a/benchmarks/float8/bench_multi_gpu.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -import os -from typing import Callable - -import fire -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn -import torch.utils.benchmark as benchmark -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType -from torchao.float8.float8_linear_utils import ( - convert_to_float8_training, - sync_float8_amax_and_scale_history, -) - -torch.manual_seed(0) - -# TODO: Add more shapes for the benchmark -B, M, K, N = 32, 1024, 1024, 1024 -lr = 0.01 - -config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), -) - - -def benchmark_torch_function_in_microseconds( - func: Callable, - *args, - **kwargs, -) -> float: - t0 = benchmark.Timer( - stmt="func(*args, **kwargs)", - globals={"args": args, "kwargs": kwargs, "func": func}, - ) - return t0.blocked_autorange().median * 1e6 - - -def setup(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - - -def cleanup(): - dist.destroy_process_group() - - -def get_model(K, N, is_fp8, base_dtype=torch.float32): - modules = [ - nn.Linear(K, N, dtype=base_dtype), - nn.ReLU(), - ] - N_LAYERS = 20 - # N linear layers - for _ in range(N_LAYERS - 1): - modules.append(nn.Linear(N, N, dtype=base_dtype)) - modules.append(nn.ReLU()) - m = nn.Sequential(*modules) - if is_fp8: - convert_to_float8_training( - m, - config=config, - ) - return m - - -def fsdp_main(rank, world_size, args): - setup(rank, world_size) - torch.cuda.set_device(rank) - - base_dtype, input_global, compile = args - - # basic distributed data sampling - assert B % world_size == 0 - bsz_local_start = int(rank / world_size * B) - bsz_local_end = int((rank + 1) / world_size * B) - input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank) - - fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank) - # Need use_orig_params=True to compile FSDP - fp8_model = FSDP(fp8_model, use_orig_params=True) - fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size) - - # Run one iteration to make compile work, see experiments doc for more context of this issue. - fp8_optimizer.zero_grad() - y_local = fp8_model(input_tensor) - y_local.sum().backward() - fp8_optimizer.step() - sync_float8_amax_and_scale_history(fp8_model) - - sync_float8_func = sync_float8_amax_and_scale_history - if compile: - # TODO: Need to fix issues with compile - fp8_model = torch.compile(fp8_model) - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) - - def float8_forw_backward(): - fp8_optimizer.zero_grad() - y_local = fp8_model(input_tensor) - y_local.sum().backward() - fp8_optimizer.step() - sync_float8_func(fp8_model) - - ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank) - ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size) - if compile: - ref_model = torch.compile(ref_model) - - ref_model = FSDP(ref_model, use_orig_params=True) - - def ref_forw_backward(): - ref_optimizer.zero_grad() - ref_model(input_tensor).sum().backward() - ref_optimizer.step() - - def run_n_iterations(n, fn): - for _ in range(n): - fn() - # make sure training is done on all ranks - dist.barrier() - - # warmup - run_n_iterations(50, ref_forw_backward) - run_n_iterations(50, float8_forw_backward) - - N_ITER = 50 - ref_time = ( - benchmark_torch_function_in_microseconds( - run_n_iterations, N_ITER, ref_forw_backward - ) - * 1e-6 - / N_ITER - ) - float8_time = ( - benchmark_torch_function_in_microseconds( - run_n_iterations, N_ITER, float8_forw_backward - ) - * 1e-6 - / N_ITER - ) - - if rank == 0: - print("ref_time", ref_time) - print("float8_time", float8_time) - print("float8 speedup", ref_time / float8_time) - - cleanup() - - -def run(compile: bool): - base_dtype = torch.bfloat16 - WORLD_SIZE = torch.cuda.device_count() - print(f"{base_dtype = }") - print(f"{compile = }") - print(f"{WORLD_SIZE = }") - - # generate input data - ref_input = torch.randn(B, M, K).cuda().to(base_dtype) - # run fsdp model - args = (base_dtype, ref_input, compile) - mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) - - -# Usgae: -# CUDA_VISIBLE_DEVICES=0,1 python benchmarks/bench_multi_gpu.py -if __name__ == "__main__": - fire.Fire(run) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 684ed0af2a..6f30e5eff7 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -58,9 +58,7 @@ ) from torchao.float8 import ( - CastConfig, Float8LinearConfig, - ScalingType, convert_to_float8_training, ) from torchao.float8.roofline_utils import ( @@ -219,24 +217,6 @@ def run( scaling_type_weight="dynamic", scaling_type_grad_output="dynamic", ) - fp8_mem_time_sympy_del_limit = get_float8_mem_sympy( - M, - K, - N, - model_torch_compile_limitations=True, - scaling_type_input="delayed", - scaling_type_weight="delayed", - scaling_type_grad_output="delayed", - ) - fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy( - M, - K, - N, - model_torch_compile_limitations=False, - scaling_type_input="delayed", - scaling_type_weight="delayed", - scaling_type_grad_output="delayed", - ) if gemm_time_strategy == "roofline": bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16) @@ -258,16 +238,12 @@ def run( # roofline memory overhead estimates "fp8_oh_dyn_limit", "fp8_oh_dyn_nolimit", - "fp8_oh_del_limit", - "fp8_oh_del_nolimit", # actual e2e measurements "bf16_s", "fp8_dyn_s", - "fp8_del_s", "fp8_dyn_axs_s", # 'fp8_lw_s', "fp8_dyn_sp", - "fp8_del_sp", "fp8_dyn_axs_sp", # 'fp8_lw_sp', ] @@ -309,12 +285,6 @@ def run( fp8_mem_time_dyn_nolimit_s = ( fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) ) - fp8_mem_time_del_limit_s = ( - fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) - fp8_mem_time_del_nolimit_s = ( - fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) - ) # create the model m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16() @@ -333,19 +303,6 @@ def run( m_fp8_dyn = torch.compile(m_fp8_dyn) fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x) - # get the float8 delayed scaling gpu kernel time - torch._dynamo.reset() - config = Float8LinearConfig( - enable_amax_init=False, - enable_pre_and_post_forward=False, - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config) - m_fp8_del = torch.compile(m_fp8_del) - fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x) - # get the float8 dynamic axiswise scaling gpu kernel time torch._dynamo.reset() config = Float8LinearConfig.from_recipe_name("rowwise") @@ -374,16 +331,12 @@ def run( # roofline overhead estimates fp8_mem_time_dyn_limit_s, fp8_mem_time_dyn_nolimit_s, - fp8_mem_time_del_limit_s, - fp8_mem_time_del_nolimit_s, # e2e numbers bf16_time_actual_s, fp8_dyn_time_actual_s, - fp8_del_time_actual_s, fp8_dyn_axs_time_actual_s, # fp8_lw_time_actual_s, bf16_time_actual_s / fp8_dyn_time_actual_s, - bf16_time_actual_s / fp8_del_time_actual_s, bf16_time_actual_s / fp8_dyn_axs_time_actual_s, # bf16_time_actual_s / fp8_lw_time_actual_s, ] diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 687684d4e2..e28ed6dcc2 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -33,19 +33,15 @@ kernel_name_to_category, parse_bw_and_kernel_name, profiler_output_to_filtered_time_by_kernel_name, - profiler_output_to_gpu_time_for_key, update_triton_kernels_in_prof_chome_trace_with_torch_logs, ) -from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( Float8LinearConfig, ScalingType, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -286,9 +282,7 @@ def main( model_type: str = "linear", dtype_filter: str = "both", add_inductor_metadata_to_trace: bool = True, - enable_sync_amax_history: bool = True, enable_activation_checkpointing: bool = False, - enable_float8_delayed_scaling_inductor_passes: bool = False, ): assert model_type in ( "linear", @@ -325,12 +319,6 @@ def main( print( f"enable_activation_checkpointing is set to {enable_activation_checkpointing}" ) - print( - f"enable_float8_delayed_scaling_inductor_passes is set to {enable_float8_delayed_scaling_inductor_passes}" - ) - - if enable_float8_delayed_scaling_inductor_passes: - _prototype_register_float8_delayed_scaling_inductor_passes() device = "cuda" ref_dtype = torch.bfloat16 @@ -388,17 +376,9 @@ def float8_forw(x): out = m_float8(x) return out - sync_amax_history = sync_float8_amax_and_scale_history - def float8_forw_backward_wrapper(x): - # sync_float8_amax_and_scale_history is not full graph torch - # compile friendly, so we add a high level wrapper to allow - # inspection of the fw+bw torch.compile without the scale - # syncing code - # TODO(future): make this better - if linear_requires_sync(config) and enable_sync_amax_history: - with record_function("scale_amax_and_scales"): - sync_amax_history(m_float8) + # TODO(future PR): this wrapper is for delayed scaling, we can clean it + # up now that delayed scaling is deprecated. out = float8_forw(x) # out.sum().backward() is also not torch.compile fullgraph @@ -409,11 +389,6 @@ def float8_forw_backward_wrapper(x): if compile: m_ref = torch.compile(m_ref, fullgraph=True) float8_forw = torch.compile(float8_forw, fullgraph=True) - # Note: it's faster to compile the combination of sync_amax_history wit - # forward because we only look up from dynamo cache once. - # However, compiling the sync function separately makes it more - # convenient to analyze the total time spent on it. - sync_amax_history = torch.compile(sync_amax_history) # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output # to populate triton kernel bandwidth further down in the script @@ -529,13 +504,6 @@ def float8_forw_backward_wrapper(x): ] ) - # get the time spent per user annotation - sync_time_us = profiler_output_to_gpu_time_for_key( - p, "scale_amax_and_scales" - ) - sync_time_ms = sync_time_us / profile_iters / 1e3 - print(f"Sync time ms: {sync_time_ms}") - finally: if f is not None: # print the redirected stdout back to regular stdout @@ -586,14 +554,6 @@ def float8_forw_backward_wrapper(x): df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"] df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"] - # calculate sync time as pct of total float time - # note: this time is not useful if TORCHINDUCTOR_PROFILE is on - total_float8_ms = df_p.iloc[3]["1_float8"] - sync_approx_ratio = sync_time_ms / total_float8_ms - print( - f"\nFloat8 amax/scale sync approx ratio of total time: {sync_approx_ratio:.3f}" - ) - print("\nSummary of time (ms) by kernel category\n\n", df_p) diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 818b413a77..463b618fa8 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -26,7 +26,6 @@ from torchao.float8.config import ( - CastConfig, Float8LinearConfig, Float8LinearRecipeName, ScalingGranularity, @@ -37,8 +36,6 @@ from torchao.float8.float8_linear import Float8Linear from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_python_api import addmm_float8_unwrapped from torchao.float8.float8_scaling_utils import ( @@ -55,11 +52,9 @@ from torchao.float8.float8_utils import ( FP8_TYPES, compute_error, - config_has_stateful_scaling, fp8_tensor_statistics, tensor_to_scale, ) -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config random.seed(0) @@ -285,16 +280,10 @@ def _test_linear_impl( config: Float8LinearConfig, use_ac: bool = False, ): - if config_has_stateful_scaling(config): - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - else: - m_fp8 = Float8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) for _ in range(2): if use_ac: @@ -302,8 +291,6 @@ def _test_linear_impl( else: y_fp8 = m_fp8(x) y_fp8.sum().backward() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m_fp8) if use_ac: y_ref = torch.utils.checkpoint.checkpoint(m_ref, x, use_reentrant=False) @@ -321,65 +308,21 @@ def _test_linear_impl( if m_ref.bias is not None: torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad) - # verify all of the amax buffers got updated - if linear_requires_sync(config): - # only check buffers that are actually used, based on per-tensor - # scaling settings - amax_buffer_names = [] - amax_history_buffer_names = [] - scale_buffer_names = [] - if config.cast_config_input.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_input") - amax_history_buffer_names.append("fp8_amax_history_input") - scale_buffer_names.append("fp8_scale_input") - if config.cast_config_weight.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_weight") - amax_history_buffer_names.append("fp8_amax_history_weight") - scale_buffer_names.append("fp8_scale_weight") - if config.cast_config_grad_output.scaling_type is ScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_grad_output") - amax_history_buffer_names.append("fp8_amax_history_grad_output") - scale_buffer_names.append("fp8_scale_grad_output") - - # verify all of the amax buffers got updated - max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES} - for buffer_name in amax_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - for init_val in max_float8_pos: - assert torch.ne( - buffer_value, torch.tensor(init_val) - ), f"{buffer_name} not filled, current value {buffer_value}" - - # verify all of the amax history buffers got updated - for buffer_name in amax_history_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - assert torch.max(buffer_value) > 0.0, f"{buffer_name} not filled" - - # verify all of the scale buffers got updated - for buffer_name in scale_buffer_names: - buffer_value = getattr(m_fp8, buffer_name) - assert torch.ne( - buffer_value, torch.tensor(1.0) - ), f"{buffer_name} not filled, current value {buffer_value}" - - # verify initialization flags got updated - assert m_fp8.is_amax_initialized, "Amax was not properly initialized" - @pytest.mark.parametrize( "emulate", [True, False] if is_sm_at_least_89() else [True] ) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @@ -467,9 +410,6 @@ def test_autocast_outputs( nn.Linear(32, 32, device="cuda", dtype=linear_dtype), ) config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) m = convert_to_float8_training(copy.deepcopy(m_ref), config=config) @@ -477,21 +417,15 @@ def test_autocast_outputs( # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): y = m(x) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) assert ( y.dtype == torch.bfloat16 ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}" @@ -510,40 +444,18 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): # Cast the module to dtype m = m.to(dtype=linear_dtype) - if linear_requires_sync(config): - # Check amax buffer types - for key in [ - "fp8_amax_input", - "fp8_amax_history_input", - "fp8_scale_input", - "fp8_amax_weight", - "fp8_amax_history_weight", - "fp8_scale_weight", - "fp8_amax_grad_output", - "fp8_amax_history_grad_output", - "fp8_scale_grad_output", - ]: - assert ( - m._buffers[key].dtype == torch.float32 - ), f"{key}.dtype is {m._buffers[key].dtype}, expected torch.float32" # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}" # autocast on with torch.autocast("cuda"): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}" with torch.autocast("cuda", dtype=torch.bfloat16): - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(m) y = m(x) assert ( y.dtype == torch.bfloat16 @@ -552,7 +464,6 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): def test_repr(self): m = nn.Linear(32, 16) config = Float8LinearConfig( - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), emulate=True, ) m = Float8Linear.from_float( @@ -560,7 +471,7 @@ def test_repr(self): config=config, ) s = m.__repr__() - assert "i:dyn_ten_e4m3,w:del_ten_e4m3,go:dyn_ten_e5m2" in s + assert "i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2" in s @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available") def test_inference_mode(self): diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 0c02db26a6..7c31bf6f08 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -7,7 +7,6 @@ import random import sys import unittest -from dataclasses import replace from io import StringIO import pytest @@ -26,7 +25,6 @@ from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend -from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -35,20 +33,11 @@ e4m3_dtype, ) from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_linear_utils import ( - convert_to_float8_training, - get_float8_layers, - sync_float8_amax_and_scale_history, -) from torchao.float8.float8_scaling_utils import ( - hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig -from torchao.float8.float8_utils import config_has_stateful_scaling -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config -from torchao.utils import is_fbcode def _test_compile_base( @@ -66,16 +55,10 @@ def _test_compile_base( x_ref = copy.deepcopy(x) m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype) - if config_has_stateful_scaling(config): - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - else: - m_fp8 = Float8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) + m_fp8 = Float8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) m_ref = torch.compile(m_ref, backend=backend, fullgraph=fullgraph) @@ -94,16 +77,14 @@ def _test_compile_base( @pytest.mark.parametrize("fullgraph", [True]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -133,16 +114,14 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -171,16 +150,14 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) -@pytest.mark.parametrize( - "scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC] -) +@pytest.mark.parametrize("scaling_type_input", [ScalingType.DYNAMIC]) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @unittest.skipIf( not torch.cuda.is_available() or not is_sm_at_least_89(), @@ -241,16 +218,12 @@ class TestGraphBreaks(DynamoTestCase): class MockLinear(torch.nn.Module): def __init__(self, graph_break: bool): super().__init__() - self.register_buffer("fp8_amax_x", torch.tensor(1.0)) - self.register_buffer("fp8_scale_x", torch.tensor(1.0)) self.graph_break = graph_break def forward(self, x): - x_fp8 = hp_tensor_to_float8_delayed( + x_fp8 = hp_tensor_to_float8_dynamic( x, - self.fp8_scale_x, e4m3_dtype, - self.fp8_amax_x, LinearMMConfig(), ) if self.graph_break: @@ -330,30 +303,6 @@ def test_float8_graph_output(self): ) -@unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", -) -def test_sync_amax_func(): - torch._dynamo.reset() - cnts = CompileCounterWithBackend("inductor") - module = torch.nn.Sequential( - nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) - ) - config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - float8_mod = convert_to_float8_training( - module, - config=config, - ) - compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts) - compiled_swap_func(float8_mod) - assert cnts.frame_count == 1, "Compiled graph should have 1 frame!" - - class capture_stderr(list): """ Replace sys.stderr with a temporary StringIO @@ -371,38 +320,6 @@ def __exit__(self, *args): sys.stderr = self.sys_stderr -@unittest.skipIf( - not torch.cuda.is_available() or not is_sm_at_least_89(), - "CUDA with float8 support not available", -) -def test_sync_amax_func_cuda_graph_success(): - torch._dynamo.reset() - with capture_stderr() as stderr: - my_module = nn.Sequential( - nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) - ).to("cuda") - config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), - ) - convert_to_float8_training( - my_module, - config=config, - ) - inpt = torch.randn( - 16, 16, device="cuda", dtype=torch.float32, requires_grad=True - ) - sync_func = torch.compile( - sync_float8_amax_and_scale_history, mode="reduce-overhead", fullgraph=True - ) - fp8_layers = get_float8_layers(my_module) - my_module(inpt) - sync_func(my_module, fp8_layers) - - assert "skipping cudagraphs due to mutaton on input" not in stderr[0] - - @unittest.skipIf( not is_sm_at_least_89(), "CUDA not available", @@ -475,70 +392,5 @@ def test_dynamic_scale_numeric_parity( assert torch.equal(float8_eager._data, float8_compile._data) -@unittest.skipIf( - not is_sm_at_least_89() or not is_fbcode(), - "CUDA with float8 support not available; or not on fbcode (the test needs be run with the latest pytorch package)", -) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) -def test_delayed_scaling_pattern_replacement(dtype: torch.dtype): - from torch._inductor import config as inductor_config - from torch._inductor import metrics - - inductor_config.loop_ordering_after_fusion = True - - def clear_all(): - metrics.reset() - from torch._inductor.fx_passes.post_grad import ( - pass_patterns as post_grad_patterns_all, - ) - - post_grad_patterns_all[1].clear() - post_grad_patterns_all[1].seen_patterns.clear() - - def compile_and_run_single_layer(): - random.seed(0) - torch.manual_seed(0) - x_shape = (2048, 3072) - linear_dtype = dtype - - x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() - m_ref = nn.Linear(3072, 2048, bias=True, device="cuda", dtype=linear_dtype) - - config = get_test_float8_linear_config( - ScalingType.DELAYED, - ScalingType.DELAYED, - ScalingType.DELAYED, - False, - ) - - config = replace(config, enable_amax_init=False) - - m_fp8 = StatefulFloat8Linear.from_float( - copy.deepcopy(m_ref), - config, - ) - - m_fp8 = torch.compile(m_fp8, backend="inductor", fullgraph=True) - m_ref = torch.compile(m_ref, backend="inductor", fullgraph=True) - - y_fp8 = m_fp8(x) - y_fp8.sum().backward() - - return m_fp8.weight.grad - - clear_all() - ref_output = compile_and_run_single_layer() - ref_count_kernel = metrics.generated_kernel_count - - clear_all() - _prototype_register_float8_delayed_scaling_inductor_passes() - new_output = compile_and_run_single_layer() - new_count_kernel = metrics.generated_kernel_count - - torch.equal(ref_output, new_output) - # With the pattern replacement workaround, amax reduction kernels for the 3 tensors (weight, activation, gradient) are fused. - assert ref_count_kernel == new_count_kernel + 3 - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/float8/test_fsdp.py b/test/float8/test_fsdp.py index 863256dc35..3017c8b539 100644 --- a/test/float8/test_fsdp.py +++ b/test/float8/test_fsdp.py @@ -35,11 +35,9 @@ FullyShardedDataParallel as FSDP, ) -from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import compute_error @@ -77,19 +75,13 @@ def get_model(K, N, base_dtype=torch.float32): def fsdp_main(rank, world_size, args): setup(rank, world_size) torch.cuda.set_device(rank) + print("args", args) - emulate, base_dtype, compile, use_weight_dynamic_scaling = args + emulate, base_dtype, compile = args model = get_model(K, N, base_dtype=base_dtype).to(rank) model_fp8 = copy.deepcopy(model) - scaling_type_weight = ( - ScalingType.DYNAMIC if use_weight_dynamic_scaling else ScalingType.DELAYED - ) - config = Float8LinearConfig( - cast_config_weight=CastConfig(scaling_type=scaling_type_weight), - # TODO(future): delete this arg as it's always False - emulate=False, - ) + config = Float8LinearConfig() # Note: we only iterate over `scaling_type_weight` because FSDP only interacts # with weights. @@ -110,6 +102,7 @@ def fsdp_main(rank, world_size, args): # Note: we need two different inputs to properly measure the impact of # delayed scaling, before the first input uses dynamic scaling to # populate the buffers + # TODO(future PR): delete ^, since we deleted delayed scaling ref_input_global = [ torch.randn(B, M, K).cuda().to(base_dtype), torch.randn(B, M, K).cuda().to(base_dtype), @@ -133,16 +126,10 @@ def fsdp_main(rank, world_size, args): ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank) ) - sync_float8_func = sync_float8_amax_and_scale_history - if compile: - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) - def forward_backward(model, optim, is_fp8, i): optim.zero_grad() y_local = model(ref_input_local[i]) y_local.backward(ref_grad_local[i]) - if is_fp8 and linear_requires_sync(config): - sync_float8_func(model) optim.step() return y_local @@ -193,7 +180,7 @@ def forward_backward(model, optim, is_fp8, i): cleanup() -def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False): +def run(compile_fsdp: bool = False): base_dtype = torch.bfloat16 emulate = False @@ -207,7 +194,7 @@ def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False): emulate = True WORLD_SIZE = torch.cuda.device_count() - args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling) + args = (emulate, base_dtype, compile_fsdp) mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/test/float8/test_fsdp.sh b/test/float8/test_fsdp.sh index 3ff19d917d..6f135a2e76 100755 --- a/test/float8/test_fsdp.sh +++ b/test/float8/test_fsdp.sh @@ -4,12 +4,12 @@ set -e launch() { - echo "launching compile_fsdp $COMPILE, use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING" + echo "launching compile_fsdp $COMPILE" # the NCCL_DEBUG setting is to avoid log spew # the CUDA_VISIBLE_DEVICES setting is for easy debugging NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp.py \ - --compile_fsdp $COMPILE --use_weight_dynamic_scaling $USE_WEIGHT_DYNAMIC_SCALING + --compile_fsdp $COMPILE echo "✅ All Tests Passed ✅" } @@ -19,10 +19,5 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; exit fi -# COMPILE, USE_WEIGHT_DYNAMIC_SCALING -for i in False,False False,True True,False True,True -do - IFS=","; set -- $i; - COMPILE=$1; USE_WEIGHT_DYNAMIC_SCALING=$2 - launch -done +COMPILE=False launch +COMPILE=True launch diff --git a/test/float8/test_fsdp2/test_fsdp2.py b/test/float8/test_fsdp2/test_fsdp2.py index 0beb012406..a36fc3e249 100644 --- a/test/float8/test_fsdp2/test_fsdp2.py +++ b/test/float8/test_fsdp2/test_fsdp2.py @@ -104,7 +104,6 @@ def test_transformer_parity(self): "precompute": [False, True], "scaling_type_weight": [ ScalingType.DYNAMIC, - ScalingType.DELAYED, ], "compile_transformer_block": [False, True], "dtype": [torch.float32, torch.bfloat16], @@ -122,8 +121,6 @@ def _test_transformer_parity( ): if not enable_fsdp_float8_all_gather and precompute: return - elif scaling_type_weight is ScalingType.DELAYED and precompute: - return # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the @@ -465,16 +462,10 @@ def test_fp32_fp8_single_module_parity(self): """ choices = itertools.product( [False, True], - [ScalingType.DYNAMIC, ScalingType.DELAYED, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: - if scaling_type_weight is ScalingType.STATIC: - cast_config_weight = CastConfig( - scaling_type=scaling_type_weight, - static_scale=torch.tensor([1.0], device="cuda"), - ) - else: - cast_config_weight = CastConfig(scaling_type=scaling_type_weight) + cast_config_weight = CastConfig(scaling_type=scaling_type_weight) float8_linear_config1 = Float8LinearConfig( enable_fsdp_float8_all_gather=False, @@ -517,7 +508,7 @@ def test_fp32_fp8_multi_module_parity(self): """ choices = itertools.product( [False, True], - [ScalingType.DYNAMIC, ScalingType.DELAYED], + [ScalingType.DYNAMIC], ) for enable_fsdp_float8_all_gather, scaling_type_weight in choices: float8_linear_config1 = Float8LinearConfig( @@ -587,26 +578,6 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self): self.get_local_inp(torch.bfloat16), ) - @unittest.skipIf(not TEST_CUDA, "no cuda") - def test_delayed_scaling_inplace_update(self): - """ - Verify that `WeightWithDelayedFloat8CastTensor` updates buffers inplace - """ - module = self.init_single_module() - float8_linear_config = Float8LinearConfig( - enable_fsdp_float8_all_gather=True, - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - ) - m_fp8 = convert_to_float8_training( - module, - config=float8_linear_config, - ) - - fp8_amax_weight_old = m_fp8.fp8_amax_weight.clone().detach() - dummy_mesh = None - data, scale = m_fp8.weight.fsdp_pre_all_gather(dummy_mesh) - self.assertNotEqual(fp8_amax_weight_old.item(), m_fp8.fp8_amax_weight.item()) - if __name__ == "__main__": run_tests() diff --git a/test/float8/test_fsdp_compile.py b/test/float8/test_fsdp_compile.py index 1d95801f67..a78a30925c 100644 --- a/test/float8/test_fsdp_compile.py +++ b/test/float8/test_fsdp_compile.py @@ -26,10 +26,8 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torchao.float8 import Float8LinearConfig -from torchao.float8.config import CastConfig, ScalingType from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - sync_float8_amax_and_scale_history, ) torch.manual_seed(0) @@ -63,10 +61,6 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): # https://gist.github.com/vkuzo/ed8e168fd9f7463f1fce34301334ab55 # to get around this, we can disable amax init config = Float8LinearConfig( - enable_amax_init=False, - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), emulate=emulate, ) @@ -102,7 +96,6 @@ def fsdp_main(rank, world_size, args): optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) input_local = torch.randn(B, M, K, N, device="cuda") - sync_float8_func = torch.compile(sync_float8_amax_and_scale_history) model = torch.compile(model) @@ -111,7 +104,6 @@ def fsdp_main(rank, world_size, args): with torch.autocast("cuda"): y_local = model(input_local) y_local.sum().backward() - sync_float8_func(model) optimizer.step() print("done!") diff --git a/test/float8/test_numerics_integration.py b/test/float8/test_numerics_integration.py index 01e4cbb20d..f25c876189 100644 --- a/test/float8/test_numerics_integration.py +++ b/test/float8/test_numerics_integration.py @@ -31,8 +31,6 @@ ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_utils import IS_ROCM, compute_error from torchao.testing.float8.test_utils import get_test_float8_linear_config @@ -115,7 +113,7 @@ def _test_impl(self, config: Float8LinearConfig) -> None: # Note: you need two different inputs to properly test numerics # of delayed scaling, because the first time around the initialization # logic of delayed scaling behaves as dynamic scaling - # TODO(future): also make unit tests do this properly + # TODO(future PR): delete ^, since we deleted delayed scaling shape = (1, 8192, 4096) data1 = torch.randn(*shape, device="cuda", dtype=data_dtype) data2 = torch.randn(*shape, device="cuda", dtype=data_dtype) @@ -127,36 +125,21 @@ def _test_impl(self, config: Float8LinearConfig) -> None: model_ref_out = model_ref(data2) model_ref_out.sum().backward() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model_fp8) model_fp8(data1).sum().backward() # zero out grads without stepping, since we just want to compare grads # of the second datum optim_fp8.zero_grad() - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model_fp8) model_fp8_out = model_fp8(data2) model_fp8_out.sum().backward() out_sqnr = compute_error(model_ref_out, model_fp8_out) - any_static_scaling = ( - config.cast_config_input.scaling_type is ScalingType.STATIC - or config.cast_config_weight.scaling_type is ScalingType.STATIC - or config.cast_config_grad_output.scaling_type is ScalingType.STATIC - ) - if any_static_scaling: - assert out_sqnr > 10.0 - else: - assert out_sqnr > 20.0 + assert out_sqnr > 20.0 ref_name_to_grad = { name: param.grad for name, param in model_ref.named_parameters() } - if any_static_scaling: - grad_sqnr_threshold = 10.0 - else: - grad_sqnr_threshold = 20.0 + grad_sqnr_threshold = 20.0 for name, param in model_fp8.named_parameters(): ref_grad = ref_name_to_grad[name] @@ -166,15 +149,15 @@ def _test_impl(self, config: Float8LinearConfig) -> None: @pytest.mark.parametrize( "scaling_type_input", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_weight", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.parametrize( "scaling_type_grad_output", - [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC], + [ScalingType.DYNAMIC], ) @pytest.mark.skipif( not is_sm_at_least_89(), reason="requires SM89 compatible machine" diff --git a/torchao/float8/README.md b/torchao/float8/README.md index 4dbc556d83..65105d1f89 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -15,8 +15,6 @@ throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs. # Single GPU User API -We provide three per-tensor scaling strategies: dynamic, delayed and static. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`). - ## float8 linear with dynamic tensorwise scaling This is the default recipe, with a good balance of performance and accuracy. @@ -114,67 +112,6 @@ for _ in range(10): optimizer.step() ``` -## float8 linear with delayed scaling - -:warning: We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details. - -This is theoretically the most performant recipe as it minimizes memory reads. - -```python -import torch -import torch.nn as nn -from torchao.float8 import ( - convert_to_float8_training, - sync_float8_amax_and_scale_history, - Float8LinearConfig, - ScalingType, - CastConfig, -) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 - -if not TORCH_VERSION_AT_LEAST_2_5: - raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") - -# Recommended: enable additional torchinductor passes to improve the performance of delayed scaling -torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() - -# create model and sample input -m = nn.Sequential( - nn.Linear(2048, 4096), - nn.Linear(4096, 128), -).bfloat16().cuda() -x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16) -optimizer = torch.optim.SGD(m.parameters(), lr=0.1) - -# configure delayed scaling -config = Float8LinearConfig( - cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), - cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), -) - -# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior -convert_to_float8_training(m, config=config) - -# enable torch.compile for competitive performance -m = torch.compile(m) - -# toy training loop -for _ in range(10): - optimizer.zero_grad() - y = m(x) - y.sum().backward() - - # Specific to delayed scaling: separate step to sync scales/amaxes. - # On the first call, this function also sets the `is_amax_initialized` flag to - # mark the amax and scale buffers as initialized. - # Make sure you run this after every model forward+backward pass. - # In the future, this may move to a context manager. - sync_float8_amax_and_scale_history(m) - - optimizer.step() -``` - # Multi GPU User API We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/stable/distributed.tensor.parallel.html), @@ -226,10 +163,6 @@ There are three observations we can make about the formula above: For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium shapes, (1) and (3) are of similar magnitude and the speedup depends on M, K, N and framework and compiler behavior. For large shapes, (1) leads to speedup > 1. -## Scaling type vs speedup - -Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling without workarounds. We have a prototype workaround (API subject to change) with the `torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()` API to improve delayed scaling performance. - ## torch.compile behavior vs speedup There are a couple of limitations in how torch.compile generates float8 scaling and casting kernels (see the performance section of https://github.com/pytorch/ao/issues/556). As the limitations get resolved, we expect to reach improved performance. diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 258db53be0..18ef82a507 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -6,15 +6,12 @@ # Lets define a few top level things here from torchao.float8.config import ( CastConfig, - DelayedScalingConfig, Float8GemmConfig, Float8LinearConfig, ScalingType, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, - linear_requires_sync, - sync_float8_amax_and_scale_history, ) from torchao.float8.float8_tensor import ( Float8Tensor, @@ -23,11 +20,7 @@ ScaledMMConfig, ) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp -from torchao.float8.inductor_utils import ( - _prototype_register_float8_delayed_scaling_inductor_passes, -) from torchao.float8.inference import Float8MMConfig -from torchao.float8.stateful_float8_linear import WeightWithDelayedFloat8CastTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if TORCH_VERSION_AT_LEAST_2_5: @@ -41,22 +34,17 @@ GemmInputRole, LinearMMConfig, Float8MMConfig, - WeightWithDelayedFloat8CastTensor, ] ) __all__ = [ # configuration - "DelayedScalingConfig", "ScalingType", "Float8GemmConfig", "Float8LinearConfig", "CastConfig", # top level UX "convert_to_float8_training", - "linear_requires_sync", - "sync_float8_amax_and_scale_history", "precompute_float8_dynamic_scale_for_fsdp", - "_prototype_register_float8_delayed_scaling_inductor_passes", # note: Float8Tensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/config.py b/torchao/float8/config.py index fa03d55b11..d2998d890f 100644 --- a/torchao/float8/config.py +++ b/torchao/float8/config.py @@ -15,20 +15,14 @@ class ScalingType(enum.Enum): - DELAYED = "delayed" DYNAMIC = "dynamic" - STATIC = "static" # ScalingType.DISABLED means "skip scaling for this tensor, leave it in # its original precision. DISABLED = "disabled" def short_str(self): - if self is ScalingType.DELAYED: - return "del" - elif self is ScalingType.DYNAMIC: + if self is ScalingType.DYNAMIC: return "dyn" - elif self is ScalingType.STATIC: - return "sta" else: assert self is ScalingType.DISABLED return "dis" @@ -90,7 +84,6 @@ class CastConfig: scaling_type: ScalingType = ScalingType.DYNAMIC scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE - static_scale: Optional[torch.Tensor] = None target_dtype: Optional[torch.dtype] = None def short_str(self): @@ -98,10 +91,6 @@ def short_str(self): return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}" def __post_init__(self): - if self.scaling_type is ScalingType.STATIC: - assert ( - self.static_scale is not None - ), "static_scale must be specified for static scaling" if self.scaling_granularity is ScalingGranularity.AXISWISE: assert ( self.scaling_type is ScalingType.DYNAMIC @@ -111,30 +100,6 @@ def __post_init__(self): ), "must specify a 8-bit floating-point dtype" -@dataclass(frozen=True) -class DelayedScalingConfig: - """ - Configuration for delayed scaling. - - Note: for now, `history_len` values must be the same for all layers in the - model using delayed scaling. - - TODO(future): serialization for recipes - """ - - # Controls the history length of amax buffers - history_len: int = 16 - - # Controls the way to calculate current scale from amax history - # TODO(future): add other functions as needed, hardcoded or user defined - scale_fn_name: str = "max" - - def __post_init__(self): - assert ( - self.scale_fn_name == "max" - ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." - - @dataclass(frozen=True) class Float8GemmConfig: """ @@ -215,14 +180,6 @@ class Float8LinearConfig: # Per-linear configuration # - # This configuration option is deprecated and no longer has an effect. It may - # be removed in a future release. - enable_amax_init: bool = True - - # This configuration option is deprecated and no longer has an effect. It may - # be removed in a future release. - enable_pre_and_post_forward: bool = True - # If True, then uses a tensor subclass for the float8 linear module's weight that # implements pre/post-all-gather methods to do float8 all-gather with FSDP2. enable_fsdp_float8_all_gather: bool = False @@ -236,13 +193,6 @@ class Float8LinearConfig: # If True, emulation is used instead of hardware accelerated gemm emulate: bool = False - # Configuration for delayed scaling - # Note: this is actually applied per-tensor, but only using the same - # configuration for all tensors and layers in the model is currently - # supported. If in the future we add support for a more fine grained - # configuration, this field may move to per-tensor configs. - delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig() - # If the option is enabled, fp8_weight will always be re-computed in backward. # It's recommended to enable this flag when using FSDP. # Otherwise, the entire fp8_weight, instead of the sharded weight may be saved. @@ -336,16 +286,6 @@ def __post_init__(self): "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd." ) - # Future deprecation warning for delayed scaling - if ( - self.cast_config_input.scaling_type != ScalingType.DYNAMIC - or self.cast_config_weight.scaling_type != ScalingType.DYNAMIC - or self.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC - ): - logger.warning( - "Note: delayed and static scaling will be deprecated in a future release of torchao. Please see https://github.com/pytorch/ao/issues/1680 for more details." - ) - @staticmethod def from_recipe_name( recipe_name: Union[Float8LinearRecipeName, str], diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py index d822d33042..9d5cdd3242 100644 --- a/torchao/float8/float8_linear.py +++ b/torchao/float8/float8_linear.py @@ -64,8 +64,6 @@ class matmul_with_hp_or_float8_args(torch.autograd.Function): * if the arguments are in high precision, they are cast to float8 according to the specified config * if the arguments are in float8, we assume the cast honored the config - - Only supports dynamic scaling, does not support delayed/static scaling. """ @staticmethod @@ -259,8 +257,7 @@ class Float8Linear(torch.nn.Linear): inside of this repository. Please file an issue if you would benefit from this being a public API. - A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks - scales in way friendly to delayed scaling. + A wrapper around a `torch.nn.Linear` module which does fp8 compute. """ def __init__(self, *args, **kwargs): @@ -411,6 +408,7 @@ def from_float( # 1. weight needs to be on the correct device to create the buffers # 2. buffers need to be already created for the delayed scaling version # of the weight wrapper to be initialized + # TODO(future PR): see if we can simplify ^ now that delayed scaling is deleted if config.enable_fsdp_float8_all_gather: assert config.cast_config_weight.scaling_type is ScalingType.DYNAMIC new_mod.weight = torch.nn.Parameter( diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index 3649b741cc..db9889567f 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -6,56 +6,15 @@ import logging from typing import Callable, Optional -import torch -import torch.distributed as dist import torch.nn as nn -from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import Float8LinearConfig, ScalingType +from torchao.float8.config import Float8LinearConfig from torchao.float8.float8_linear import Float8Linear -from torchao.float8.float8_utils import ( - amax_history_to_scale_stack, - config_has_stateful_scaling, -) -from torchao.float8.stateful_float8_linear import StatefulFloat8Linear log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) -def linear_requires_sync(config: Float8LinearConfig): - """Returns whether the given linear_type requires sync before forward.""" - return any( - [ - config.cast_config_input.scaling_type is ScalingType.DELAYED, - config.cast_config_weight.scaling_type is ScalingType.DELAYED, - config.cast_config_grad_output.scaling_type is ScalingType.DELAYED, - ] - ) - - -def _update_history_stack( - new_amax: torch.Tensor, amax_history_stack: torch.Tensor -) -> torch.Tensor: - """ - Updates `amax_history` (the last N cur_amax values) inplace with the value - of `new_amax`. - - Args: - new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1) - amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length) - """ - assert ( - amax_history_stack.dim() == 2 - ), f"Expected amat_history_stack to be 2D, got {amax_history_stack.shape()}" - assert ( - new_amax.size(0) == amax_history_stack.size(0) - ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}" - new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1) - new_amax_history_stack[:, 0] = new_amax.squeeze(-1) - amax_history_stack.copy_(new_amax_history_stack) - - def swap_linear_layers( module: nn.Module, from_float_func: Callable[[nn.Linear], nn.Linear], @@ -144,196 +103,13 @@ def convert_to_float8_training( if config is None: config = Float8LinearConfig() - if config_has_stateful_scaling(config): - from_float = lambda m: StatefulFloat8Linear.from_float( - m, - config=config, - ) - else: - from_float = lambda m: Float8Linear.from_float( - m, - config=config, - ) + from_float = lambda m: Float8Linear.from_float( + m, + config=config, + ) return swap_linear_layers( module, from_float, module_filter_fn=module_filter_fn, ) - - -def get_float8_layers(model: torch.nn.Module): - """Iterates through the model and returns all the Float8Linear layers. - Args: - model (torch.nn.Module): The model to look for Float8Linear layers in. - """ - - # Get all fp8 layers and tensors - fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)] - if not torch.compiler.is_compiling(): - for layer in fp8_layers: - for buf in layer.buffers(): - torch._dynamo.mark_static_address(buf, guard=True) - return fp8_layers - - -@torch.no_grad() -def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None: - """ - Manages the float8 amax and scale bookkeeping. In detail, it does the - following: - 1. in distributed contexts, syncs amax values across workers for activations and gradients - 2. adds the `amax` values to history - 3. calculates the scales to be used for next iteration - 4. sets the `amax_and_scale_synced` flag on the Float8Linear modules - to signal that they have been synced - - TODO(future): design the UX for this (context manager, etc) - - PERFORMANCE NOTE: - When you can, it is much more efficient to call get_float8_layers once at - the beginning of the training loop and pass the result to this function. - Because of how this interacts with torch.compile - - Args: - model (torch.nn.Module): The model to track amaxes for - fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored, - and we loop over all fp8_layers to sync and update amax scale histories. - Users can use get_float8_layers to get all fp8 layers. - """ - # TODO(future): consider adding a flag to control setting the `is_amax_initialized` - # flag only on the first iteration. - - if fp8_layers is None: - fp8_layers = get_float8_layers(model) - - if len(fp8_layers) == 0: - log.warn( - "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers" - ) - return - - def inner_func(): - """Why do we have this inner_function? - - There are two portions of the outer sync_function that cause graph_breaks: - 1. The `get_float8_layers` call can cause graph breaks if the user did not pass - in the fp8_layers. - 2. At the end of syncing all the amaxes and scales we set the attr on the module - signaling that we have synced the amaxes and scales and the next forward can be run. - # TODO Maybe we should remove this safety check to remove the graph break? - - By having this inner function, we can ensure that although the outer function may cause graph breaks - the inner function will not. - """ - # Loop over all fp8 layers and grab the needed tensors - fp8_amax_input_tensor_list = [None] * len(fp8_layers) - fp8_amax_weight_tensor_list = [None] * len(fp8_layers) - fp8_amax_grad_output_tensor_list = [None] * len(fp8_layers) - - fp8_input_amax_history_stack = [None] * len(fp8_layers) - fp8_weight_amax_history_stack = [None] * len(fp8_layers) - fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) - - input_dtypes = set() - weight_dtypes = set() - grad_output_dtypes = set() - scale_fn_recipes = set() - - for idx, child in enumerate(fp8_layers): - fp8_amax_input_tensor_list[idx] = child.fp8_amax_input - fp8_amax_weight_tensor_list[idx] = child.fp8_amax_weight - fp8_amax_grad_output_tensor_list[idx] = child.fp8_amax_grad_output - - fp8_input_amax_history_stack[idx] = child.fp8_amax_history_input - fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight - fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output - - input_dtypes.add(child.config.cast_config_input.target_dtype) - weight_dtypes.add(child.config.cast_config_weight.target_dtype) - grad_output_dtypes.add(child.config.cast_config_grad_output.target_dtype) - scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name) - - (input_dtype,) = input_dtypes - (weight_dtype,) = weight_dtypes - (grad_output_dtype,) = grad_output_dtypes - - if len(scale_fn_recipes) != 1: - raise ValueError( - f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" - ) - scale_fn_recipe = next(iter(scale_fn_recipes)) - - assert ( - len(fp8_amax_input_tensor_list) - == len(fp8_amax_weight_tensor_list) - == len(fp8_amax_grad_output_tensor_list) - ), "Mismatched lengths of amax tensors." - - if dist.is_initialized(): - all_amax_tensors = torch.cat( - fp8_amax_input_tensor_list - + fp8_amax_weight_tensor_list - + fp8_amax_grad_output_tensor_list - ) - all_reduced_amax_tensor = all_reduce( - all_amax_tensors, "MAX", list(range(dist.get_world_size())) - ) - if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor): - all_reduced_amax_tensor = all_reduced_amax_tensor.wait() - - ( - reduced_fp8_amax_input_tensor, - reduced_fp8_amax_weight_tensor, - reduced_fp8_amax_grad_output_tensor, - ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_input_tensor_list)) - - for idx, child in enumerate(fp8_layers): - child.fp8_amax_input.copy_(reduced_fp8_amax_input_tensor[idx]) - child.fp8_amax_weight.copy_(reduced_fp8_amax_weight_tensor[idx]) - child.fp8_amax_grad_output.copy_( - reduced_fp8_amax_grad_output_tensor[idx] - ) - - # We create two stacked tensor groups, one for the amax history and one for the current scales - fp8_amax_input_tensors = torch.vstack(fp8_amax_input_tensor_list) - fp8_amax_weight_tensors = torch.vstack(fp8_amax_weight_tensor_list) - fp8_amax_grad_output_tensors = torch.vstack(fp8_amax_grad_output_tensor_list) - - fp8_input_amax_history_stack = torch.vstack(fp8_input_amax_history_stack) - fp8_weight_amax_history_stack = torch.vstack(fp8_weight_amax_history_stack) - fp8_grad_output_amax_history_stack = torch.vstack( - fp8_grad_output_amax_history_stack - ) - - # Update the history stacks with the new amax values - _update_history_stack(fp8_amax_input_tensors, fp8_input_amax_history_stack) - _update_history_stack(fp8_amax_weight_tensors, fp8_weight_amax_history_stack) - _update_history_stack( - fp8_amax_grad_output_tensors, fp8_grad_output_amax_history_stack - ) - - # Calculate the new scales from the updated history stacks - new_input_scales = amax_history_to_scale_stack( - fp8_input_amax_history_stack, input_dtype, scale_fn_recipe - ) - new_weight_scales = amax_history_to_scale_stack( - fp8_weight_amax_history_stack, weight_dtype, scale_fn_recipe - ) - new_grad_output_scales = amax_history_to_scale_stack( - fp8_grad_output_amax_history_stack, grad_output_dtype, scale_fn_recipe - ) - - # Iterate through the layers and update the scales - for idx, child in enumerate(fp8_layers): - child.fp8_scale_input.copy_(new_input_scales[idx]) - child.fp8_scale_weight.copy_(new_weight_scales[idx]) - child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx]) - - # This allows for the compile to succeed on the inner func and fail on the graph breaks - # at the beginning and and of syncing - inner_func() - - for child in fp8_layers: - # Set a flag to signal that initialization is done - child.is_amax_initialized = True diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py index b96c7a9b58..31f2db6b4e 100644 --- a/torchao/float8/float8_scaling_utils.py +++ b/torchao/float8/float8_scaling_utils.py @@ -21,8 +21,6 @@ hp_tensor_and_scale_to_float8, ) from torchao.float8.float8_utils import ( - amax_history_to_scale, - tensor_to_amax, tensor_to_scale, ) @@ -74,72 +72,6 @@ def hp_tensor_to_float8_dynamic( ) -def hp_tensor_to_float8_delayed( - hp_tensor: torch.Tensor, - s: torch.Tensor, - float8_dtype: torch.dtype, - amax_buffer: torch.Tensor, - linear_mm_config: Optional[LinearMMConfig] = None, - gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, -) -> Float8Tensor: - """ - Given a high precision tensor `hp_tensor` and relevant metadata, scales it using - delayed scaling and returns a `Float8Tensor` of the result. Specifically: - 1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace - 2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor - - Args: - hp_tensor: the tensor to convert - s: the scale to use to convert the tensor - float8_dtype: the float8 dtype to use - amax_buffer: the buffer to modify inplace with max(abs(hp_tensor)) - linear_mm_config: Defines the configuration for the scaled_mm for - the 3 fwd/bwd gemms of linear - gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in - the 3 fwd/bwd gemms of linear - """ - amax_buffer.fill_(tensor_to_amax(hp_tensor)) - return hp_tensor_and_scale_to_float8( - hp_tensor, - s, - float8_dtype, - linear_mm_config, - gemm_input_role, - ) - - -def hp_tensor_to_float8_static( - hp_tensor: torch.Tensor, - scale: torch.Tensor, - float8_dtype: torch.dtype, - linear_mm_config: LinearMMConfig, - gemm_input_role: GemmInputRole = GemmInputRole.INPUT, -) -> Float8Tensor: - """ - Given a high precision tensor `hp_tensor` and a scale, - scales `hp_tensor` returns a `Float8Tensor` of the result. - - Args: - hp_tensor: the tensor to convert - scale: the scale to use - float8_dtype: the float8 dtype to use - linear_mm_config: Defines the configuration for the scaled_mm for - the 3 fwd/bwd gemms of linear - gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in - the 3 fwd/bwd gemms of linear - """ - if tensor_already_casted_to_fp8(hp_tensor): - return hp_tensor - - return hp_tensor_and_scale_to_float8( - hp_tensor, - scale, - float8_dtype, - linear_mm_config, - gemm_input_role, - ) - - def get_maybe_axiswise_dim( axiswise_dim: int, scaling_granularity: ScalingGranularity, @@ -155,95 +87,6 @@ def get_maybe_axiswise_dim( return None -def _maybe_initialize_amaxes_scales_for_float8_cast( - x, - cur_amax, - amax_history, - scale, - scale_fn_name, - float8_dtype, - is_initialized, - reduce_amax, -): - """ - If x is about to be cast to `float8` and the amax buffers are not initialized, - initializes them inplace. - """ - if is_initialized: - return - with torch.no_grad(): - # Note: we need to enable distributed reduction here in order - # to match numerics between single GPU and multi GPU code for - # activations and gradients - new_amax = tensor_to_amax(x, reduce_amax=reduce_amax) - cur_amax.fill_(new_amax) - amax_history[0] = new_amax - new_scale = amax_history_to_scale(amax_history, float8_dtype, scale_fn_name) - scale.copy_(new_scale) - - -@torch._dynamo.allow_in_graph -class NoopFwToFloat8BwDelayed(torch.autograd.Function): - """ - Forward: no-op - Backward: convert to float8_e5m2 with delayed scaling, initialize if needed - """ - - @staticmethod - def forward( - ctx, - tensor, - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - scale_fn_name, - is_amax_initialized, - linear_mm_config: LinearMMConfig, - target_dtype: torch.dtype, - ): - ctx.save_for_backward( - fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output - ) - ctx.scale_fn_name = scale_fn_name - ctx.is_amax_initialized = is_amax_initialized - ctx.linear_mm_config = linear_mm_config - ctx.target_dtype = target_dtype - return tensor - - @staticmethod - def backward(ctx, go): - ( - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - ) = ctx.saved_tensors - scale_fn_name = ctx.scale_fn_name - is_amax_initialized = ctx.is_amax_initialized - - _maybe_initialize_amaxes_scales_for_float8_cast( - go, - fp8_amax_grad_output, - fp8_amax_history_grad_output, - fp8_scale_grad_output, - scale_fn_name, - ctx.target_dtype, - is_amax_initialized, - reduce_amax=True, - ) - - fp8_amax_grad_output.fill_(tensor_to_amax(go)) - - res = hp_tensor_and_scale_to_float8( - go, - fp8_scale_grad_output, - ctx.target_dtype, - ctx.linear_mm_config, - GemmInputRole.GRAD_OUTPUT, - ) - empty_grads = None, None, None, None, None, None, None - return res, *empty_grads - - @torch._dynamo.allow_in_graph class NoopFwToFloat8BwDynamic(torch.autograd.Function): """ @@ -275,38 +118,3 @@ def backward(ctx, gradY): GemmInputRole.GRAD_OUTPUT, ) return fp8_tensor, None, None - - -@torch._dynamo.allow_in_graph -class NoopFwToFloat8BwStatic(torch.autograd.Function): - """ - Forward: no-op - Backward: convert to float8_e5m2 with static scaling - """ - - @staticmethod - def forward( - ctx, - tensor, - scale, - linear_mm_config: LinearMMConfig, - target_dtype: torch.dtype, - ): - ctx.save_for_backward(scale) - ctx.linear_mm_config = linear_mm_config - ctx.target_dtype = target_dtype - return tensor - - @staticmethod - def backward(ctx, gradY): - if tensor_already_casted_to_fp8(gradY): - return gradY, None, None, None - (gradY_scale,) = ctx.saved_tensors - fp8_tensor = hp_tensor_and_scale_to_float8( - gradY, - gradY_scale, - ctx.target_dtype, - ctx.linear_mm_config, - GemmInputRole.GRAD_OUTPUT, - ) - return fp8_tensor, None, None, None diff --git a/torchao/float8/float8_tensor_parallel.py b/torchao/float8/float8_tensor_parallel.py index a52b38b6bf..abc74e3ff6 100644 --- a/torchao/float8/float8_tensor_parallel.py +++ b/torchao/float8/float8_tensor_parallel.py @@ -27,8 +27,7 @@ def _float8_linear_supports_float8_allgather(m): - # TODO(future): add support for delayed scaling for activations - # and gradients + # TODO(future PR): also gate this by granularity return ( m.scaling_type_input == ScalingType.DYNAMIC and m.scaling_type_grad_output == ScalingType.DYNAMIC diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py index 926b97edb8..625fb29235 100644 --- a/torchao/float8/float8_utils.py +++ b/torchao/float8/float8_utils.py @@ -4,13 +4,13 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce -from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType +from torchao.float8.config import ScalingGranularity # Helpful visualizer for debugging (only supports fp32): # https://www.h-schmidt.net/FloatConverter/IEEE754.html @@ -53,44 +53,6 @@ def amax_to_scale( return res -@torch.no_grad() -def amax_history_to_scale( - amax_history: torch.Tensor, - float8_dtype: torch.Tensor, - history_to_scale_fn_type: Literal["max"], -): - """Takes in a history of amax values and returns a scale tensor. - Args: - amax_history: A tensor containing the history of amax values. - float8_dtype: The float8 dtype. - history_to_scale_fn_type: The type of function to use to convert the history to a scale. - """ - if history_to_scale_fn_type == "max": - amax = torch.max(amax_history) - return amax_to_scale(amax, float8_dtype) - raise NotImplementedError() - - -@torch.no_grad() -def amax_history_to_scale_stack( - amax_history: torch.Tensor, - float8_dtype: torch.dtype, - history_to_scale_fn_type: Literal["max"], -) -> torch.Tensor: - """Takes in a stack of amax_history tensors and returns a scale tensor. - Args: - amax_history: A 2D tensor containing a stack of amax histories. - float8_dtype: The float8 dtype. - history_to_scale_fn_type: The type of function to use to convert the history to a scale. - """ - if history_to_scale_fn_type == "max": - amax_stack = torch.max(amax_history, dim=1).values - return amax_to_scale(amax_stack, float8_dtype) - raise NotImplementedError( - f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}" - ) - - @torch.no_grad() def tensor_to_amax( x: torch.Tensor, @@ -274,17 +236,6 @@ def pad_tensor_for_matmul( return torch.nn.functional.pad(tensor, (0, pad_dim2, 0, pad_dim1)) -def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: - """ - Returns True if `config` has any delayed or static scaling, and False otherwise. - """ - return ( - config.cast_config_input.scaling_type != ScalingType.DYNAMIC - or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC - or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC - ) - - def _round_scale_down_to_power_of_2(scale: torch.Tensor): assert scale.dtype == torch.float32, "scale must be float32 tensor" return torch.exp2(torch.floor(torch.log2(scale))) diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py index f246879a7c..7b24dc2b53 100644 --- a/torchao/float8/fsdp_utils.py +++ b/torchao/float8/fsdp_utils.py @@ -13,8 +13,6 @@ from torch._prims_common import suggest_memory_format from torchao.float8.float8_scaling_utils import ( - _maybe_initialize_amaxes_scales_for_float8_cast, - hp_tensor_to_float8_delayed, hp_tensor_to_float8_dynamic, ) from torchao.float8.float8_tensor import ( @@ -39,14 +37,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: """ from torch.distributed._tensor import DTensor - from torchao.float8.config import ScalingType from torchao.float8.float8_linear import Float8Linear - if any( - isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED - for m in module.modules() - ): - raise NotImplementedError("Only supports dynamic scaling") float8_linears: List[Float8Linear] = [ m for m in module.modules() @@ -274,331 +266,3 @@ def fsdp_post_all_gather( self._linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT, ), (data,) - - -class WeightWithDelayedFloat8CastTensor(torch.Tensor): - @staticmethod - def __new__( - cls, - tensor: torch.Tensor, - amax_buffer: torch.Tensor, - amax_history_buffer: torch.Tensor, - scale_buffer: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - is_amax_initialized: bool, - ): - return torch.Tensor._make_wrapper_subclass( - cls, - tensor.size(), - strides=tensor.stride(), - storage_offset=tensor.storage_offset(), - memory_format=suggest_memory_format(tensor), - dtype=tensor.dtype, - layout=tensor.layout, - device=tensor.device, - pin_memory=tensor.is_pinned(), - requires_grad=tensor.requires_grad, - ) - - def __init__( - self, - tensor: torch.Tensor, - amax_buffer: torch.Tensor, - amax_history_buffer: torch.Tensor, - scale_buffer: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - is_amax_initialized: bool, - ): - self._tensor = tensor - self._amax_buffer = amax_buffer - self._amax_history_buffer = amax_history_buffer - self._scale_buffer = scale_buffer - self._linear_mm_config = linear_mm_config - self._dtype = dtype - - # Note: is_amax_initialized is not a buffer to avoid data dependent - # control flow visible to dynamo - # TODO(future PR): add serialization for this flag - self.is_amax_initialized = is_amax_initialized - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func == torch.ops.aten.detach.default: - return WeightWithDelayedFloat8CastTensor( - args[0]._tensor, - args[0]._amax_buffer, - args[0]._amax_history_buffer, - args[0]._scale_buffer, - args[0]._linear_mm_config, - args[0]._dtype, - args[0].is_amax_initialized, - ) - mm_config: Optional[LinearMMConfig] = None - dtype: Optional[torch.dtype] = None - amax_buffer: Optional[torch.Tensor] = None - amax_history_buffer: Optional[torch.Tensor] = None - scale_buffer: Optional[torch.Tensor] = None - is_amax_initialized: Optional[bool] = None - - def unwrap(t): - nonlocal mm_config - if mm_config is None: - mm_config = t._linear_mm_config - else: - assert t._linear_mm_config == mm_config - nonlocal dtype - if dtype is None: - dtype = t._dtype - else: - assert t._dtype == dtype - nonlocal amax_buffer - if amax_buffer is None: - amax_buffer = t._amax_buffer - nonlocal amax_history_buffer - if amax_history_buffer is None: - amax_history_buffer = t._amax_history_buffer - nonlocal scale_buffer - if scale_buffer is None: - scale_buffer = t._scale_buffer - nonlocal is_amax_initialized - if is_amax_initialized is None: - is_amax_initialized = t.is_amax_initialized - return t._tensor - - args, kwargs = pytree.tree_map_only( - WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {}) - ) - out = func(*args, **kwargs) - if func not in _ops_to_preserve_subclass: - return out - return pytree.tree_map_only( - torch.Tensor, - lambda x: WeightWithDelayedFloat8CastTensor( - x, - amax_buffer, - amax_history_buffer, - scale_buffer, - mm_config, - dtype, - is_amax_initialized, - ), - out, - ) - - def __tensor_flatten__(self): - return ( - [ - "_tensor", - "_amax_buffer", - "_amax_history_buffer", - "_scale_buffer", - ], - { - "mm_config": self._linear_mm_config, - "dtype": self._dtype, - "is_amax_initialized": self.is_amax_initialized, - }, - ) - - @staticmethod - def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): - return WeightWithDelayedFloat8CastTensor( - inner_tensors["_tensor"], - inner_tensors["_amax_buffer"], - inner_tensors["_amax_history_buffer"], - inner_tensors["_scale_buffer"], - metadata["mm_config"], - metadata["dtype"], - metadata["is_amax_initialized"], - ) - - def __repr__(self): - return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._linear_mm_config}, dtype={self._dtype})" - - def fsdp_pre_all_gather(self, mesh): - # initialize if needed - # TODO(before land): ensure settings are consistent between Float8Linear and here - if not self.is_amax_initialized: - _maybe_initialize_amaxes_scales_for_float8_cast( - self._tensor, - self._amax_buffer, - self._amax_history_buffer, - self._scale_buffer, - "max", # TODO(before land): read this from parent - self._dtype, - self.is_amax_initialized, - reduce_amax=True, - ) - self.is_amax_initialized = True - - float8_tensor = hp_tensor_to_float8_delayed( - self._tensor, - self._scale_buffer, - self._dtype, - self._amax_buffer, - self._linear_mm_config, - GemmInputRole.WEIGHT, - ) - return (float8_tensor._data,), (float8_tensor._scale,) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, - *, - out: Optional[torch.Tensor] = None, - ): - (data,) = all_gather_outputs - (scale,) = metadata - if out is not None: - assert isinstance(out, Float8Tensor), f"{type(out)}" - out._scale = scale - return - return Float8Tensor( - data, - scale, - param_dtype, - self._linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ), (data,) - - -class WeightWithStaticFloat8CastTensor(torch.Tensor): - @staticmethod - def __new__( - cls, - tensor: torch.Tensor, - static_scale: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - ): - return torch.Tensor._make_wrapper_subclass( - cls, - tensor.size(), - strides=tensor.stride(), - storage_offset=tensor.storage_offset(), - memory_format=suggest_memory_format(tensor), - dtype=tensor.dtype, - layout=tensor.layout, - device=tensor.device, - pin_memory=tensor.is_pinned(), - requires_grad=tensor.requires_grad, - ) - - def __init__( - self, - tensor: torch.Tensor, - static_scale: torch.Tensor, - linear_mm_config: LinearMMConfig, - dtype: torch.dtype, - ): - self._tensor = tensor - self._static_scale = static_scale - self._linear_mm_config = linear_mm_config - self._dtype = dtype - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - if func == torch.ops.aten.detach.default: - return WeightWithStaticFloat8CastTensor( - args[0]._tensor, - args[0]._static_scale, - args[0]._linear_mm_config, - args[0]._dtype, - ) - static_scale: Optional[torch.Tensor] = None - mm_config: Optional[LinearMMConfig] = None - dtype: Optional[torch.dtype] = None - - def unwrap(t): - nonlocal static_scale - if static_scale is None: - static_scale = t._static_scale - nonlocal mm_config - if mm_config is None: - mm_config = t._linear_mm_config - else: - assert t._linear_mm_config == mm_config - nonlocal dtype - if dtype is None: - dtype = t._dtype - else: - assert t._dtype == dtype - return t._tensor - - args, kwargs = pytree.tree_map_only( - WeightWithStaticFloat8CastTensor, unwrap, (args, kwargs or {}) - ) - out = func(*args, **kwargs) - if func not in _ops_to_preserve_subclass: - return out - return pytree.tree_map_only( - torch.Tensor, - lambda x: WeightWithStaticFloat8CastTensor( - x, static_scale, mm_config, dtype - ), - out, - ) - - def __tensor_flatten__(self): - return ["_tensor", "_static_scale"], { - "mm_config": self._linear_mm_config, - "dtype": self._dtype, - } - - @staticmethod - def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - return WeightWithStaticFloat8CastTensor( - inner_tensors["_tensor"], - inner_tensors["_static_scale"], - flatten_spec["mm_config"], - flatten_spec["dtype"], - ) - - def __repr__(self): - return f"WeightWithStaticFloat8CastTensor(tensor={self._tensor}, static_scale={self._static_scale}, linear_mm_config={self._linear_mm_config}, dtype={self.dtype})" - - def fsdp_pre_all_gather(self, mesh): - float8_tensor = hp_tensor_and_scale_to_float8( - self._tensor, - self._static_scale, - self._dtype, - self._linear_mm_config, - GemmInputRole.WEIGHT, - ) - return (float8_tensor._data,), (float8_tensor._scale,) - - def fsdp_post_all_gather( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, - *, - out: Optional[torch.Tensor] = None, - ): - (data,) = all_gather_outputs - (scale,) = metadata - if out is not None: - from torch.distributed._tensor import DTensor - - if isinstance(out, Float8Tensor): - out._scale = scale - elif isinstance(out, DTensor) and isinstance( - out._local_tensor, Float8Tensor - ): - out._local_tensor._scale = scale - else: - raise RuntimeError( - f"out must be a Float8Tensor or DTensor(_local_tensor=Float8Tensor), but got {out}" - ) - return - return Float8Tensor( - data, - scale, - param_dtype, - self._linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ), (data,) diff --git a/torchao/float8/inductor_utils.py b/torchao/float8/inductor_utils.py deleted file mode 100644 index 3e86202536..0000000000 --- a/torchao/float8/inductor_utils.py +++ /dev/null @@ -1,126 +0,0 @@ -import functools -import inspect -import traceback -from collections import deque - -import torch - - -def amax_with_scaling_pattern(tensor_x_inp, scale_x, fp8_dtype, fp8_max): - tensor_x = tensor_x_inp.to(torch.float32) * scale_x - tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) - tensor_x = tensor_x.to(fp8_dtype) - amax = torch.max(torch.abs(tensor_x_inp)) - return (tensor_x, amax) - - -def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, fp8_dtype, fp8_max): - tensor_x = tensor_x_inp.to(torch.float32) * scale_x - tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) - tensor_x = tensor_x.to(fp8_dtype) - amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values - amax = torch.max(amax_1) - return (tensor_x, amax) - - -# The amax_with_scaling_pattern will also match dynamic scaling cases, we want to avoid that. -# `scale_x` of delayed scaling comes from the previous iteration, instead of from `tensor_x_inp`. -# We check that `scale_x` is not a dependency of `tensor_x_inp` -def fp8_delayed_scaling_extra_check(match): - scale_x_inputs = deque([match.kwargs["scale_x"]]) - max_num_node_to_check = 20 # Don't traverse too many nodes - current_num_node = 0 - while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check: - current_node = scale_x_inputs.popleft() - for n in current_node.all_input_nodes: - if n == match.kwargs["tensor_x_inp"]: - return False - scale_x_inputs.append(n) - current_num_node += 1 - return True - - -def partialize_and_update_signature(func, **kwargs): - """ - Equivalent to functools.partial but also updates the signature on returned function - """ - original_sig = inspect.signature(func) - parameters = original_sig.parameters - - new_parameters = { - key: value for key, value in parameters.items() if key not in kwargs - } - new_sig = inspect.Signature(parameters=list(new_parameters.values())) - - partial_func = functools.partial(func, **kwargs) - - def wrapper(*args, **kwargs): - return partial_func(*args, **kwargs) - - wrapper.__signature__ = new_sig # type: ignore[attr-defined] - wrapper.__name__ = func.__name__ - - return wrapper - - -def register_fp8_delayed_scaling_patterns_inner(): - from torch._inductor.fx_passes.post_grad import ( - pass_patterns as post_grad_patterns_all, - ) - from torch._inductor.pattern_matcher import fwd_only, register_replacement - - post_grad_patterns = post_grad_patterns_all[1] # medium priority - - if torch.cuda.is_available(): - for fp8_dtype in [ - torch.float8_e4m3fn, - torch.float8_e5m2, - torch.float8_e4m3fnuz, - torch.float8_e5m2fnuz, - ]: - # torch.float16 has the same pattern as torch.bfloat16, because they both needs `tensor_x_inp.to(torch.float32)` - for dtype in [torch.float32, torch.bfloat16]: - device = "cuda" - register_replacement( - partialize_and_update_signature( - amax_with_scaling_pattern, - fp8_dtype=fp8_dtype, - fp8_max=torch.finfo(fp8_dtype).max, - ), - partialize_and_update_signature( - amax_with_scaling_tiled_replacement, - fp8_dtype=fp8_dtype, - fp8_max=torch.finfo(fp8_dtype).max, - ), - [ - torch.tensor((16, 16), device=device, dtype=dtype), - torch.tensor(2.0, device=device, dtype=torch.float32), - ], - fwd_only, - post_grad_patterns, - extra_check=fp8_delayed_scaling_extra_check, - ) - - -""" -This a short-term workaround of the delayed scaling performance issue. -It explicitly replaces `max(x)` with `max(max(x, dim=-1))`, enabling the fusion of amax scaling factor calculation and fp8 casting. - -Usage: - To use this solution, add the following line at the beginning of your user code: - torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() -""" - - -def _prototype_register_float8_delayed_scaling_inductor_passes() -> None: - # To make the fp8 delayed scaling pattern work, we need a fix pr from inductor, https://github.com/pytorch/pytorch/pull/139321 - # Will throw the error if the pattern registration did not work, up to user to decide what to do with it - try: - register_fp8_delayed_scaling_patterns_inner() - except AssertionError as e: - if "assert pattern_repr not in _seen_patterns" in traceback.format_exc(): - print( - f"Caught duplicated patterns in register_fp8_delayed_scaling_patterns: {traceback.format_exc()}", - "\nPlease update your pytorch dependency to the latest main branch to fix it.\n", - ) - raise e diff --git a/torchao/float8/roofline_utils.py b/torchao/float8/roofline_utils.py index 16cf847fe2..58c84c5fa6 100644 --- a/torchao/float8/roofline_utils.py +++ b/torchao/float8/roofline_utils.py @@ -38,78 +38,30 @@ def get_tensor_memory_traffic_bytes( # assumes input bf16, output f8 numel = dim0 * dim1 - if scaling_type == "dynamic": - # x_bf16 = ... - # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 - - if fuse_with_prev: - kernel_1_rw = 0 - else: - # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) - kernel_1_rw = BYTES_PER_EL_BF16 * numel - - # kernel 3: read in bf16, write twice in float8 (row-major and col-major) - kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel - - if model_torch_compile_limitations: - # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) - # has an extra memory read of the input in fp8 - # context: https://github.com/pytorch/pytorch/issues/130015 - tc_adjustment = numel * BYTES_PER_EL_FLOAT8 - else: - tc_adjustment = 0 - - return kernel_1_rw + kernel_3_rw + tc_adjustment + assert scaling_type == "dynamic", "unsupported" + # x_bf16 = ... + # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp + # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs + # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 + + if fuse_with_prev: + kernel_1_rw = 0 + else: + # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + + # kernel 3: read in bf16, write twice in float8 (row-major and col-major) + kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel + if model_torch_compile_limitations: + # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) + # has an extra memory read of the input in fp8 + # context: https://github.com/pytorch/pytorch/issues/130015 + tc_adjustment = numel * BYTES_PER_EL_FLOAT8 else: - assert scaling_type == "delayed", "unsupported" - # x_bf16 = ... - # kernel 1: x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3 (not modeled): scale -> reciprocal -> inv_scale - - if fuse_with_prev: - kernel_1_r = 0 - else: - kernel_1_r = numel * BYTES_PER_EL_BF16 - # write twice: once in row major, once in col-major - kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2 - - if model_torch_compile_limitations: - # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...) - # has an extra memory read of the input in fp8 - # context: https://github.com/pytorch/pytorch/issues/130015 - tc_adjustment = numel * BYTES_PER_EL_FLOAT8 - - # https://github.com/pytorch/pytorch/issues/128063 - # instead of - # kernel 1: x_bf16 -> max(abs(x)), x_fp8 - # kernel 2: not modeled - # kernel 3: not modeled - # we get - # kernel 1: x_bf16 -> max(abs(x)) - # reads: same as before - # writes: 0 - # ... - # kernel 4: x_bf16, scale -> x_fp8 - # reads: numel * BYTES_PER_EL_BF16 - # writes: 2 * numel * BYTES_PER_EL_FLOAT8 - # Note that assuming worst case, this issue brings the memory - # traffic for delayed scaling to be equal to that of dynamic scaling. - tc_adjustment += ( - # subtract writes from kernel 1 - -1 * 2 * numel * BYTES_PER_EL_FLOAT8 - # add reads for kernel 4 - + numel * BYTES_PER_EL_BF16 - # add writes for kernel 4 - + 2 * numel * BYTES_PER_EL_FLOAT8 - ) - else: - tc_adjustment = 0 - - return kernel_1_r + kernel_1_w + tc_adjustment + tc_adjustment = 0 + + return kernel_1_rw + kernel_3_rw + tc_adjustment def get_gemm_time_sympy(M, K, N, dtype): @@ -131,9 +83,9 @@ def get_float8_mem_sympy( scaling_type_weight: str = "dynamic", scaling_type_grad_output: str = "dynamic", ): - assert scaling_type_input in ("dynamic", "delayed"), "unsupported" - assert scaling_type_weight in ("dynamic", "delayed"), "unsupported" - assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported" + assert scaling_type_input in ("dynamic",), "unsupported" + assert scaling_type_weight in ("dynamic",), "unsupported" + assert scaling_type_grad_output in ("dynamic",), "unsupported" # there are three gemms in the fwd/bwd of a linear: # @@ -207,27 +159,12 @@ def get_float8_mem_sympy( if scaling_type_input == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_input == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 if scaling_type_weight == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_weight == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 if scaling_type_grad_output == "dynamic": # second stage of max-abs reduction num_extra_kernels += 1 - elif scaling_type_grad_output == "delayed": - # second stage of max-abs reduction - num_extra_kernels += 1 - # reciprocal of scale - num_extra_kernels += 1 extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC diff --git a/torchao/float8/stateful_float8_linear.py b/torchao/float8/stateful_float8_linear.py deleted file mode 100644 index ac01803e0b..0000000000 --- a/torchao/float8/stateful_float8_linear.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. - -""" -Stateful version of Float8Linear, created to keep Float8Linear simple and -only require code readers to read the stateful code if they care about delayed -or static scaling. -""" - -from typing import Optional - -import torch -import torch.utils.checkpoint as checkpoint - -from torchao.float8.config import Float8LinearConfig, ScalingType -from torchao.float8.distributed_utils import tensor_already_casted_to_fp8 -from torchao.float8.float8_linear import ( - Float8Linear, -) -from torchao.float8.float8_scaling_utils import ( - NoopFwToFloat8BwDelayed, - NoopFwToFloat8BwDynamic, - NoopFwToFloat8BwStatic, - _maybe_initialize_amaxes_scales_for_float8_cast, - hp_tensor_to_float8_delayed, - hp_tensor_to_float8_dynamic, - hp_tensor_to_float8_static, -) -from torchao.float8.float8_tensor import ( - GemmInputRole, - hp_tensor_and_scale_to_float8, -) -from torchao.float8.float8_utils import ( - tensor_to_amax, - tensor_to_scale, -) -from torchao.float8.fsdp_utils import ( - WeightWithDelayedFloat8CastTensor, - WeightWithDynamicFloat8CastTensor, - WeightWithStaticFloat8CastTensor, -) - - -@torch._dynamo.allow_in_graph -class manual_float8_matmul_with_args_in_float8(torch.autograd.Function): - """ - Like torch.matmul, but with the arguments in float8 - - Note: this function requires all arguments to already be Float8Tensor objects, - which only supports tensorwise scaling granularity. The reason we didn't just make this - function support axiswise scaling granularity is because that would need very - careful testing of delayed scaling, as delayed scaling modifies buffers inplace. - - In the future we'll probably have to unify, just postponing that until a future PR. - """ - - @staticmethod - def forward( - ctx, - input_fp8, - weight_fp8_t, - ): - ctx.save_for_backward(input_fp8, weight_fp8_t) - # the reshapes are needed in order to make the shapes compatible with - # torch.mm - orig_shape = input_fp8.shape - input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) - res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) - res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) - return res_bits - - @staticmethod - def backward(ctx, grad_output_fp8): - input_fp8, weight_fp8_t = ctx.saved_tensors - - # the reshapes are needed in order to make the shapes compatible with - # torch.mm - grad_output_fp8_orig_shape = grad_output_fp8.shape - grad_output_fp8_reshaped = grad_output_fp8.reshape( - -1, grad_output_fp8_orig_shape[-1] - ) - - # calculate grad_input - grad_input = torch.mm( - grad_output_fp8_reshaped, - weight_fp8_t.t(), - ) - grad_input = grad_input.reshape( - *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1] - ) - - input_fp8_orig_shape = input_fp8.shape - input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1]) - - # calculate grad_weight - # Note: the variant below is slightly faster on LLaMa 3 8B pretraining - # compared to than calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped` - grad_weight = torch.mm( - grad_output_fp8_reshaped.t(), - input_fp8_reshaped, - ) - - return grad_input, grad_weight.t() - - -class StatefulFloat8Linear(Float8Linear): - def __init__(self, *args, **kwargs): - # Amax scales should always be kept as float32. - self.always_float32_buffers = set() - - super().__init__(*args, **kwargs) - - # Convenience flag to skip code related to delayed scaling - self.has_any_delayed_scaling = ( - self.scaling_type_input is ScalingType.DELAYED - or self.scaling_type_weight is ScalingType.DELAYED - or self.scaling_type_grad_output is ScalingType.DELAYED - ) - - self.create_buffers() - - # Note: is_amax_initialized is not a buffer to avoid data dependent - # control flow visible to dynamo - # TODO(future PR): add serialization for this flag - self.is_amax_initialized = not self.config.enable_amax_init - - # pre_forward and post_forward are currently broken with FSDP - # and torch.compile, this option can disable them - # Note that when using `self.config.enable_pre_and_post_forward = False`, - # it's recommended to also set `self.config.enable_amax_init = False`. - # Otherwise, the amax buffer would never be marked as initialized and - # would be initialized in every iteration. - self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward - - def create_buffers(self): - # Default values for history buffers, see above TODO - history_len = self.config.delayed_scaling_config.history_len - device = self.weight.device - default_input = torch.finfo(self.config.cast_config_input.target_dtype).max - default_weight = torch.finfo(self.config.cast_config_weight.target_dtype).max - default_grad_output = torch.finfo( - self.config.cast_config_grad_output.target_dtype - ).max - - # Note: for now, create all the buffers if any are needed, to postpone - # the work to make the scale and amax syncing and history calculation - # handle a heterogeneous setup. We can do that work later if benchmarks - # show it is worth doing. - if self.has_any_delayed_scaling: - self.register_always_float32_buffer( - "fp8_amax_input", torch.tensor([default_input], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_history_input", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_input", torch.tensor([1.0], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_weight", torch.tensor([default_weight], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_history_weight", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_weight", torch.tensor([1.0], device=device) - ) - self.register_always_float32_buffer( - "fp8_amax_grad_output", - torch.tensor([default_grad_output], device=device), - ) - self.register_always_float32_buffer( - "fp8_amax_history_grad_output", torch.zeros(history_len, device=device) - ) - self.register_always_float32_buffer( - "fp8_scale_grad_output", torch.tensor([1.0], device=device) - ) - - if self.config.cast_config_input.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_input", - self.config.cast_config_input.static_scale.to(device), - ) - if self.config.cast_config_weight.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_weight", - self.config.cast_config_weight.static_scale.to(device), - ) - if self.config.cast_config_grad_output.static_scale is not None: - self.register_always_float32_buffer( - "fp8_static_scale_grad_output", - self.config.cast_config_grad_output.static_scale.to(device), - ) - - def register_always_float32_buffer( - self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True - ) -> None: - self.register_buffer(name=name, tensor=tensor, persistent=persistent) - self.always_float32_buffers.add(name) - - def _apply(self, fn, recurse=True): - ret = super()._apply(fn, recurse) - self.convert_amax_buffer_to_float32() - return ret - - def convert_amax_buffer_to_float32(self): - for key in self.always_float32_buffers: - if self._buffers[key] is not None: - self._buffers[key] = self._buffers[key].to(torch.float32) - - def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor: - is_amax_initialized = self.is_amax_initialized - # Duplicate the autocast logic for F.linear, so that the output - # of our module has the right original precision - if torch.is_autocast_enabled(): - # For now, hardcode to GPU's autocast dtype - # if we need CPU support in the future, we can add it - autocast_dtype = torch.get_autocast_gpu_dtype() - input = input.to(autocast_dtype) - - if tensor_already_casted_to_fp8(input): - input_fp8 = input - elif self.scaling_type_input is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - _maybe_initialize_amaxes_scales_for_float8_cast( - input, - self.fp8_amax_input, - self.fp8_amax_history_input, - self.fp8_scale_input, - scale_fn_name, - self.config.cast_config_input.target_dtype, - is_amax_initialized, - reduce_amax=True, - ) - input_fp8 = hp_tensor_to_float8_delayed( - input, - self.fp8_scale_input, - self.config.cast_config_input.target_dtype, - self.fp8_amax_input, - linear_mm_config=self.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - ) - elif self.scaling_type_input is ScalingType.DYNAMIC: - input_fp8 = hp_tensor_to_float8_dynamic( - input, - self.config.cast_config_input.target_dtype, - self.linear_mm_config, - gemm_input_role=GemmInputRole.INPUT, - ) - else: - assert self.scaling_type_input is ScalingType.STATIC - input_fp8 = hp_tensor_to_float8_static( - input, - self.fp8_static_scale_input, - self.config.cast_config_input.target_dtype, - self.linear_mm_config, - ) - - return input_fp8 - - def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]: - if tensor_already_casted_to_fp8(weight): - return None - if self.scaling_type_weight is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - _maybe_initialize_amaxes_scales_for_float8_cast( - weight, - self.fp8_amax_weight, - self.fp8_amax_history_weight, - self.fp8_scale_weight, - scale_fn_name, - self.config.cast_config_weight.target_dtype, - self.is_amax_initialized, - reduce_amax=True, - ) - self.fp8_amax_weight.fill_(tensor_to_amax(weight)) - return self.fp8_scale_weight - elif self.scaling_type_weight is ScalingType.DYNAMIC: - return tensor_to_scale(weight, self.config.cast_config_weight.target_dtype) - else: - assert self.scaling_type_weight is ScalingType.STATIC - return self.fp8_static_scale_weight - - def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor: - if self.scaling_type_grad_output is ScalingType.DELAYED: - scale_fn_name = self.config.delayed_scaling_config.scale_fn_name - output = NoopFwToFloat8BwDelayed.apply( - output, - self.fp8_amax_grad_output, - self.fp8_amax_history_grad_output, - self.fp8_scale_grad_output, - scale_fn_name, - self.is_amax_initialized, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - elif self.scaling_type_grad_output is ScalingType.DYNAMIC: - output = NoopFwToFloat8BwDynamic.apply( - output, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - else: - assert self.scaling_type_grad_output is ScalingType.STATIC - output = NoopFwToFloat8BwStatic.apply( - output, - self.fp8_static_scale_grad_output, - self.linear_mm_config, - self.config.cast_config_grad_output.target_dtype, - ) - return output - - def cast_weight_to_float8_t( - self, - weight: torch.Tensor, - weight_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if tensor_already_casted_to_fp8(weight): - return weight.t() - weight_fp8 = hp_tensor_and_scale_to_float8( - weight, - weight_scale, - self.config.cast_config_weight.target_dtype, - self.linear_mm_config, - gemm_input_role=GemmInputRole.WEIGHT, - ) - return weight_fp8.t() - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.has_any_delayed_scaling: - self.float8_pre_forward(input) - - input_fp8 = self.cast_input_to_float8(input) - # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight, - # weight_scale should be saved. - weight_scale = self.get_weight_scale(self.weight) - - if self.config.force_recompute_fp8_weight_in_bwd: - weight_fp8_t = checkpoint.checkpoint( - self.cast_weight_to_float8_t, - self.weight, - weight_scale, - ) - else: - weight_fp8_t = self.cast_weight_to_float8_t(self.weight, weight_scale) - - output = manual_float8_matmul_with_args_in_float8.apply(input_fp8, weight_fp8_t) - - # Cast grad_output to float8_e5m2 during backward - output = self.cast_output_to_float8_in_bw(output) - - if self.bias is not None: - output = output + self.bias.to(output.dtype) - - if self.has_any_delayed_scaling: - self.float8_post_forward() - return output - - def float8_pre_forward(self, input): - # TODO(future PR): deprecate these functions and the corresponding - # config setting - if not self.enable_pre_and_post_forward: - return - - def float8_post_forward(self): - # TODO(future PR): deprecate these functions and the corresponding - # config setting - if not self.enable_pre_and_post_forward: - return - - @classmethod - def from_float( - cls, - mod, - config: Optional[Float8LinearConfig] = None, - ): - """ - Create an nn.Linear with fp8 compute from a regular nn.Linear - - Args: - mod (torch.nn.Linear): nn.Linear to convert - config (Optional[Float8LinearConfig]): configuration for conversion to float8 - """ - if config is None: - config = Float8LinearConfig() - with torch.device("meta"): - new_mod = cls( - mod.in_features, - mod.out_features, - bias=False, - config=config, - ) - new_mod.weight = mod.weight - new_mod.bias = mod.bias - # need to create buffers again when moving from meta device to - # real device - new_mod.create_buffers() - - # If FSDP float8 all-gather is on, wrap the weight in a float8-aware - # tensor subclass. This must happen last because: - # 1. weight needs to be on the correct device to create the buffers - # 2. buffers need to be already created for the delayed scaling version - # of the weight wrapper to be initialized - if config.enable_fsdp_float8_all_gather: - if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC: - new_mod.weight = torch.nn.Parameter( - WeightWithDynamicFloat8CastTensor( - new_mod.weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - ) - ) - elif config.cast_config_weight.scaling_type is ScalingType.DELAYED: - new_mod.weight = torch.nn.Parameter( - WeightWithDelayedFloat8CastTensor( - new_mod.weight, - new_mod.fp8_amax_weight, - new_mod.fp8_amax_history_weight, - new_mod.fp8_scale_weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - new_mod.is_amax_initialized, - ) - ) - else: - assert config.cast_config_weight.scaling_type is ScalingType.STATIC - new_mod.weight = torch.nn.Parameter( - WeightWithStaticFloat8CastTensor( - new_mod.weight, - new_mod.fp8_static_scale_weight, - new_mod.linear_mm_config, - new_mod.config.cast_config_weight.target_dtype, - ) - ) - - return new_mod diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/float8/fsdp2_utils.py index a059b4d2a9..31a5cf8db0 100644 --- a/torchao/testing/float8/fsdp2_utils.py +++ b/torchao/testing/float8/fsdp2_utils.py @@ -8,10 +8,6 @@ Float8LinearConfig, ScalingType, ) -from torchao.float8.float8_linear_utils import ( - linear_requires_sync, - sync_float8_amax_and_scale_history, -) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp @@ -38,9 +34,6 @@ def check_parity_no_mp( dist.all_reduce(param.grad) param.grad.div_(dist.get_world_size()) - if linear_requires_sync(config): - sync_float8_amax_and_scale_history(model) - optim.step() if ( model is fsdp_model @@ -82,7 +75,6 @@ def check_parity_bf16_mp( param_bf16.grad.div_(dist.get_world_size()) param_fp32.grad = param_bf16.grad.float() param_bf16.grad = None - # TODO(future): add amax syncing once delayed scaling is supported optim.step() for param_fp32, param_bf16 in zip( ref_model.parameters(), ref_model_bf16.parameters() diff --git a/torchao/testing/float8/test_utils.py b/torchao/testing/float8/test_utils.py index 7b8ac121b6..2da34f53ed 100644 --- a/torchao/testing/float8/test_utils.py +++ b/torchao/testing/float8/test_utils.py @@ -1,9 +1,6 @@ -import torch - from torchao.float8.config import ( CastConfig, Float8LinearConfig, - ScalingType, ) @@ -13,32 +10,14 @@ def get_test_float8_linear_config( scaling_type_grad_output, emulate: bool, ): - static_scale_one = torch.tensor([1.0], device="cuda") - - if scaling_type_input is ScalingType.STATIC: - static_scale_input = static_scale_one - else: - static_scale_input = None - if scaling_type_weight is ScalingType.STATIC: - static_scale_weight = static_scale_one - else: - static_scale_weight = None - if scaling_type_grad_output is ScalingType.STATIC: - static_scale_grad_output = static_scale_one - else: - static_scale_grad_output = None - cast_config_input = CastConfig( scaling_type=scaling_type_input, - static_scale=static_scale_input, ) cast_config_weight = CastConfig( scaling_type=scaling_type_weight, - static_scale=static_scale_weight, ) cast_config_grad_output = CastConfig( scaling_type=scaling_type_grad_output, - static_scale=static_scale_grad_output, ) config = Float8LinearConfig( From 2a3fbffc461f30751552006c864c57a80b297ca6 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Sat, 22 Feb 2025 08:49:34 -0800 Subject: [PATCH 109/115] MX Updated to_blocked to not call nn.pad (#1762) stack-info: PR: https://github.com/pytorch/ao/pull/1762, branch: drisspg/stack/38 --- torchao/prototype/mx_formats/utils.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 4cdc26109d..8b186f82d6 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import torch -import torch.nn.functional as F Tensor = torch.Tensor @@ -31,14 +30,23 @@ def to_blocked(input_matrix) -> Tensor: n_row_blocks = ceil_div(rows, 128) n_col_blocks = ceil_div(cols, 4) - # Pad out and view as tiles of (128, 4) - padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128)) - blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix - # rearrange all tiles + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) - # Layout rearranged tiles according to second pic return rearranged.flatten() From 8d3881448cc47d9005a55d5f930db3091659366d Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 24 Feb 2025 09:42:47 -0800 Subject: [PATCH 110/115] add MX support to lowp training profiling script (#1765) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- ...ear_float8.py => profile_lowp_training.py} | 192 ++++++++++-------- benchmarks/float8/utils.py | 35 +--- torchao/prototype/mx_formats/config.py | 38 +++- 3 files changed, 157 insertions(+), 108 deletions(-) rename benchmarks/float8/{profile_linear_float8.py => profile_lowp_training.py} (77%) diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_lowp_training.py similarity index 77% rename from benchmarks/float8/profile_linear_float8.py rename to benchmarks/float8/profile_lowp_training.py index e28ed6dcc2..ab242f4051 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_lowp_training.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +# This is a convenience script to profile fwd+bwd of individual layers with +# float8 training or mx training on a single GPU. + import copy import functools import io @@ -38,12 +41,13 @@ from torchao.float8.config import ( Float8LinearConfig, - ScalingType, ) from torchao.float8.float8_linear_utils import ( convert_to_float8_training, ) -from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.prototype.mx_formats.config import MXLinearConfig +from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear +from torchao.prototype.mx_formats.mx_tensor import MXTensor # don't truncate long kernel names pd.options.display.max_colwidth = 100 @@ -257,7 +261,6 @@ def profile_function( # set up AC for max(abs(tensor)) # context: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.create_selective_checkpoint_contexts ops_to_save = [ - torch.ops.aten.abs.default, torch.ops.aten.max.default, ] @@ -275,14 +278,14 @@ def policy_fn(ctx, op, *args, **kwargs): def main( profile_path_prefix: pathlib.Path, compile: bool = True, - scaling_type_input: str = "dynamic", - scaling_type_weight: str = "dynamic", - scaling_type_grad_output: str = "dynamic", - recipe_name: Optional[str] = None, + float8_recipe_name: Optional[str] = None, + mx_recipe_name: Optional[str] = None, model_type: str = "linear", - dtype_filter: str = "both", - add_inductor_metadata_to_trace: bool = True, + experiment_filter: str = "both", + add_inductor_metadata_to_trace: bool = False, enable_activation_checkpointing: bool = False, + mode_filter: str = "fwd_bwd", + forward_only: bool = False, ): assert model_type in ( "linear", @@ -290,35 +293,37 @@ def main( "norm_ffn_norm", "norm_ffn_norm_small", ), "unsupported" - assert dtype_filter in ("both", "float8", "bfloat16") - - scaling_type_input = ScalingType(scaling_type_input) - scaling_type_weight = ScalingType(scaling_type_weight) - scaling_type_grad_output = ScalingType(scaling_type_grad_output) - - if recipe_name is None: - config = get_test_float8_linear_config( - scaling_type_input, - scaling_type_weight, - scaling_type_grad_output, - emulate=False, - ) - elif recipe_name is not None: - config = Float8LinearConfig.from_recipe_name(recipe_name) - - scaling_repr = "_".join( - [ - s.short_str() - for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output) - ] - ) + assert experiment_filter in ( + "both", + "lowp", + "ref", + ), "experiment_filter must be one of `both`, `lowp`, `ref`" + assert mode_filter in ( + "fwd_bwd", + "fwd", + "cast_only", + ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`" + if mode_filter == "cast_only": + assert experiment_filter == "lowp", "unsupported" + + assert not ( + float8_recipe_name is not None and mx_recipe_name is not None + ), "either float8_recipe_name or mx_recipe_name can be specified, but not both" + + if float8_recipe_name is None and mx_recipe_name is None: + config = Float8LinearConfig() + elif float8_recipe_name is not None: + config = Float8LinearConfig.from_recipe_name(float8_recipe_name) + elif mx_recipe_name is not None: + config = MXLinearConfig.from_recipe_name(mx_recipe_name) print(f"Compile is set to | {compile}") print(f"model_type is set to | {model_type}") - print(f"scaling_repr is set to | {scaling_repr}") print( f"enable_activation_checkpointing is set to {enable_activation_checkpointing}" ) + print(f"mode_filter is set to {mode_filter}") + print(f"config: {config}") device = "cuda" ref_dtype = torch.bfloat16 @@ -359,36 +364,58 @@ def main( m_ref = m_ref.to(device).to(ref_dtype) - m_float8 = copy.deepcopy(m_ref) - convert_to_float8_training(m_float8, config=config) + # get gradient shape + with torch.no_grad(): + _ = m_ref(input_tensor) + grad_output = torch.ones_like(_) + + m_lowp = copy.deepcopy(m_ref) + if mx_recipe_name is None: + convert_to_float8_training(m_lowp, config=config) + else: + swap_linear_with_mx_linear(m_lowp, config=config) + + # this function is only used for cast_only + to_mx_func = MXTensor.to_mx + + print("m_ref", m_ref) + print("m_lowp", m_lowp) + print("input_tensor.shape", input_tensor.shape) + print("grad_output.shape", grad_output.shape) + print() def ref_forw_backward(x): + assert mode_filter != "cast_only", "unsupported" if enable_activation_checkpointing: out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn) else: out = m_ref(x) - out.sum().backward() + if mode_filter == "fwd_bwd": + out.backward(grad_output) + + def lowp_forw_backward_wrapper(x): + if mode_filter == "cast_only": + # just cast and return early + _input_tensor_mx = to_mx_func( + input_tensor, + config.elem_dtype, + config.block_size, + gemm_kernel_choice=config.gemm_kernel_choice, + ) + return - def float8_forw(x): if enable_activation_checkpointing: - out = checkpoint(m_float8, x, use_reentrant=False, context_fn=context_fn) + out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn) else: - out = m_float8(x) - return out - - def float8_forw_backward_wrapper(x): - # TODO(future PR): this wrapper is for delayed scaling, we can clean it - # up now that delayed scaling is deprecated. - out = float8_forw(x) - - # out.sum().backward() is also not torch.compile fullgraph - # friendly - with record_function("backward"): - out.sum().backward() + out = m_lowp(x) + if mode_filter == "fwd_bwd": + with record_function("backward"): + out.backward(grad_output) if compile: m_ref = torch.compile(m_ref, fullgraph=True) - float8_forw = torch.compile(float8_forw, fullgraph=True) + m_lowp = torch.compile(m_lowp, fullgraph=True) + to_mx_func = torch.compile(to_mx_func, fullgraph=True) # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output # to populate triton kernel bandwidth further down in the script @@ -398,15 +425,21 @@ def float8_forw_backward_wrapper(x): else: f = io.StringIO() context = redirect_stdout(f) + + # if we are skipping forward, enable torch.no_grad() + maybe_no_grad_context = ( + torch.no_grad() if mode_filter != "fwd_bwd" else nullcontext() + ) + try: - with context: + with context, maybe_no_grad_context: profile_iters = 5 - ref_times, float8_times = None, None + ref_times, lowp_times = None, None data = [] num_leaf_tensors = 1 + len(list(m_ref.parameters())) - if dtype_filter != "float8": + if experiment_filter != "lowp": # Profile Reference Model print("profiling ref") ref_trace_suffix = f"_{model_type}_ref_compile_{compile}.json" @@ -452,50 +485,46 @@ def float8_forw_backward_wrapper(x): ] ) - if dtype_filter != "bfloat16": - # Profile Float8 Model - print("profiling float8") - float8_trace_suffix = ( - f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json" - ) - float8_log_suffix = ( - f"_{model_type}_float8_compile_{compile}_{scaling_repr}.txt" - ) - trace_float8_path = profile_path_prefix + float8_trace_suffix - log_float8_path = profile_path_prefix + float8_log_suffix - trace_float8_modified_path = trace_float8_path.replace( + if experiment_filter != "ref": + # Profile lowp Model + print("profiling lowp") + lowp_trace_suffix = f"_{model_type}_lowp_compile_{compile}.json" + lowp_log_suffix = f"_{model_type}_lowp_compile_{compile}.txt" + trace_lowp_path = profile_path_prefix + lowp_trace_suffix + log_lowp_path = profile_path_prefix + lowp_log_suffix + trace_lowp_modified_path = trace_lowp_path.replace( ".json", "_modified.json" ) profile_config = ProfileConfig( - trace_float8_path, - log_float8_path, - trace_float8_modified_path, - float8_trace_suffix, + trace_lowp_path, + log_lowp_path, + trace_lowp_modified_path, + lowp_trace_suffix, iters=profile_iters, warmup_iters=2, sync=True, ) p = profile_function( profile_config, - float8_forw_backward_wrapper, + lowp_forw_backward_wrapper, add_inductor_metadata_to_trace, input_tensor, ) - print(f"saved profiling trace to {trace_float8_path}") + print(f"saved profiling trace to {trace_lowp_path}") if add_inductor_metadata_to_trace: - print(f"saved torch logs to {log_float8_path}") - print(f"saved modified trace to {trace_float8_modified_path}") - float8_times = profiler_output_to_filtered_time_by_kernel_name( + print(f"saved torch logs to {log_lowp_path}") + print(f"saved modified trace to {trace_lowp_modified_path}") + lowp_times = profiler_output_to_filtered_time_by_kernel_name( p, profile_iters, num_leaf_tensors ) total_time_ms = ( - sum(v for v in float8_times.values()) / 1e3 / profile_iters + sum(v for v in lowp_times.values()) / 1e3 / profile_iters ) - for k, v in float8_times.items(): + for k, v in lowp_times.items(): v_ms = v / 1e3 / profile_iters data.append( [ - "1_float8", + "1_lowp", k, kernel_name_to_category(k), v / 1e3 / profile_iters, @@ -509,6 +538,7 @@ def float8_forw_backward_wrapper(x): # print the redirected stdout back to regular stdout print(f.getvalue()) + # TODO(future PR): this seems to no longer work, fix it or delete it if os.environ.get("TORCHINDUCTOR_PROFILE", "") != "": # populate the triton kernel bandwidth for line in f.getvalue().split("\n"): @@ -546,13 +576,13 @@ def float8_forw_backward_wrapper(x): fill_value=0, margins=True, ) - # drop last row, which has totals across ref + float8 which does not make sense + # drop last row, which has totals across ref + lowp which does not make sense df_p = df_p[:-1] df_p = df_p.transpose() - if dtype_filter == "both": - df_p["f8_div_ref"] = df_p["1_float8"] / df_p["0_ref"] - df_p["ref_div_f8"] = df_p["0_ref"] / df_p["1_float8"] + if experiment_filter == "both": + df_p["lowp_div_ref"] = df_p["1_lowp"] / df_p["0_ref"] + df_p["ref_div_lowp"] = df_p["0_ref"] / df_p["1_lowp"] print("\nSummary of time (ms) by kernel category\n\n", df_p) diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index 60e402e60e..a7faf4757d 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -73,14 +73,6 @@ def profiler_output_to_filtered_time_by_kernel_name( # forward pass sum assert e.count == num_iter, f"unexpected number of iter for {e.key}" continue - elif e.key == "aten::fill_": - # filling the forward pass sum with 1.0 - assert e.count == num_iter, f"unexpected number of iter for {e.key}" - continue - elif e.key == "aten::copy_": - # copying 1.0 from grad_out of `sum` to grad_out of next op - assert e.count == num_iter, f"unexpected number of iter for {e.key}" - continue elif e.key == "aten::add_": # accumulating gradients into leaf tensors assert e.count == ( @@ -110,25 +102,16 @@ def profiler_output_to_gpu_time_for_key(prof, key): def kernel_name_to_category(k): # number prefix is for easy sorting - if k in ("aten::mm", "aten::addmm", "aten::_scaled_mm"): - return "0_gemm" - elif ( - # max(abs(tensor)) - ("abs" in k and "max" in k) - or - # casting pointwise to float8 - ("clamp" in k) - or - # things related to scaled_mm - ("scaled_mm" in k) - or - # syncing amaxes and scales - ("roll" in k) + if k in ( + "aten::mm", + "aten::addmm", + "aten::_scaled_mm", + "torchao::mx_fp8_bf16", + "torchao::mx_fp4_bf16", ): - # note: the above filter is approximate and will give false - # positives if model code contains other code to abs/max/clamp - return "1_f8_overhead" - return "2_other" + return "0_gemm" + else: + return "1_other" def parse_bw_and_kernel_name(line): diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index d511d2614d..de7369c1cf 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, Union import torch @@ -27,6 +27,14 @@ class MXGemmKernelChoice(Enum): # TODO(future PR): add cuBLAS here once we land pytorch/pytorch support +# Pre-made recipes for common configurations +class MXLinearRecipeName(Enum): + MXFP8_EMULATED = "mxfp8_emulated" + MXFP8_CUTLASS = "mxfp8_cutlass" + MXFP4_EMULATED = "mxfp4_emulated" + MXFP4_CUTLASS = "mxfp4_cutlass" + + @dataclass class MXLinearConfig: # block size for scaling, default is 32 to match @@ -78,3 +86,31 @@ def __post_init__(self): assert ( self.elem_dtype_grad_output_override is None ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels" + + @staticmethod + def from_recipe_name( + recipe_name: Union[MXLinearRecipeName, str], + ) -> "MXLinearConfig": + """ + Input: `MXLinearRecipeName` value, or a string representing a `MXLinearRecipeName` value + Output: a `MXLinearConfig` configured to implement the specified recipe + """ + if type(recipe_name) == str: + valid_names = [n.value for n in MXLinearRecipeName] + assert ( + recipe_name in valid_names + ), f"recipe_name {recipe_name} not in valid names {valid_names}" + recipe_name = MXLinearRecipeName(recipe_name) + + if recipe_name is MXLinearRecipeName.MXFP8_EMULATED: + return MXLinearConfig() + elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS: + return MXLinearConfig(gemm_kernel_choice=MXGemmKernelChoice.CUTLASS) + elif recipe_name is MXLinearRecipeName.MXFP4_EMULATED: + return MXLinearConfig(elem_dtype=DTYPE_FP4) + elif recipe_name is MXLinearRecipeName.MXFP8_CUTLASS: + return MXLinearConfig( + elem_dtype=DTYPE_FP4, gemm_kernel_choice=MXGemmKernelChoice.CUTLASS + ) + else: + raise AssertionError(f"unknown recipe_name {recipe_name}") From bac039fc84867e128db860107ae21283ef1a763e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 24 Feb 2025 12:45:41 -0800 Subject: [PATCH 111/115] Update README.md (#1758) --- torchao/quantization/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index a0e2ea2cc4..d2b6e0c016 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -348,6 +348,8 @@ Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. F ### Gemlite Triton Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `gemlite_uintx_weight_only`. An example can be found in `torchao/_models/llama/generate.py`. +Note: we test on gemlite 0.4.1, but should be able to use any version after that, we'd recommend to use the latest release to get the most recent performance improvements. + ### UINTx Quantization We're trying to develop kernels for low bit quantization for intx quantization formats. While the current performance is not ideal, we're hoping to continue to iterate on these kernels to improve their performance. From 09ebb120dab3bfb822447c1d0ae904c63c1c749c Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Mon, 24 Feb 2025 12:46:11 -0800 Subject: [PATCH 112/115] mx bench: add cast with to_blocked (#1771) Update [ghstack-poisoned] --- benchmarks/float8/profile_lowp_training.py | 33 ++++++++++++++++++---- torchao/prototype/mx_formats/mx_ops.py | 1 + 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/benchmarks/float8/profile_lowp_training.py b/benchmarks/float8/profile_lowp_training.py index ab242f4051..dd629e7f95 100644 --- a/benchmarks/float8/profile_lowp_training.py +++ b/benchmarks/float8/profile_lowp_training.py @@ -48,6 +48,7 @@ from torchao.prototype.mx_formats.config import MXLinearConfig from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.prototype.mx_formats.utils import to_blocked # don't truncate long kernel names pd.options.display.max_colwidth = 100 @@ -298,11 +299,15 @@ def main( "lowp", "ref", ), "experiment_filter must be one of `both`, `lowp`, `ref`" - assert mode_filter in ( - "fwd_bwd", - "fwd", - "cast_only", - ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`" + assert ( + mode_filter + in ( + "fwd_bwd", + "fwd", + "cast_only", + "cast_with_to_blocked", + ) + ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`" if mode_filter == "cast_only": assert experiment_filter == "lowp", "unsupported" @@ -378,6 +383,18 @@ def main( # this function is only used for cast_only to_mx_func = MXTensor.to_mx + # this function is used for cast_with_to_blocked + def cast_with_to_blocked(x_hp): + x_mx = MXTensor.to_mx( + x_hp, + config.elem_dtype, + config.block_size, + gemm_kernel_choice=config.gemm_kernel_choice, + ) + m, k = x_hp.shape + scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size)) + return x_mx._data, scale_blocked + print("m_ref", m_ref) print("m_lowp", m_lowp) print("input_tensor.shape", input_tensor.shape) @@ -385,7 +402,7 @@ def main( print() def ref_forw_backward(x): - assert mode_filter != "cast_only", "unsupported" + assert mode_filter not in ("cast_only", "cast_with_to_blocked"), "unsupported" if enable_activation_checkpointing: out = checkpoint(m_ref, x, use_reentrant=False, context_fn=context_fn) else: @@ -403,6 +420,9 @@ def lowp_forw_backward_wrapper(x): gemm_kernel_choice=config.gemm_kernel_choice, ) return + elif mode_filter == "cast_with_to_blocked": + _input_tensor_mx, scale = cast_with_to_blocked(input_tensor) + return if enable_activation_checkpointing: out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn) @@ -416,6 +436,7 @@ def lowp_forw_backward_wrapper(x): m_ref = torch.compile(m_ref, fullgraph=True) m_lowp = torch.compile(m_lowp, fullgraph=True) to_mx_func = torch.compile(to_mx_func, fullgraph=True) + cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True) # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output # to populate triton kernel bandwidth further down in the script diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 16e61e0653..ddc2bcd665 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -74,6 +74,7 @@ def mx_mm(aten_op, args, kwargs=None): # real MX gemm backed by torchao's CUTLASS kernels M, K, N = a.shape[0], a.shape[1], b.shape[1] assert b._data.t().is_contiguous() + # TODO(future PR): use block_size instead of hardcoding 32 a_scale = a._scale_e8m0.view(M, K // 32) b_scale = b._scale_e8m0.view(N, K // 32) a_scale_block = to_blocked(a_scale) From 089cd7e1e7cc6beba5115f04a2c5c08be7bdfe19 Mon Sep 17 00:00:00 2001 From: eellison Date: Mon, 24 Feb 2025 21:57:21 +0000 Subject: [PATCH 113/115] update mixed mm weight only quant test to work w mixed mm deletion (#1772) We're deleting mixed_mm path in https://github.com/pytorch/pytorch/pull/147151. update test to not check for mixed_mm kernel. Pull Request resolved: https://github.com/pytorch/ao/pull/1772 Approved by: https://github.com/drisspg --- test/integration/test_integration.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 7fd96e4d97..4eccdc86e2 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1243,8 +1243,6 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype): y_wo, (code,) = run_and_get_code(m_c, x) sqnr = compute_error(y_ref, y_wo) self.assertGreaterEqual(sqnr, 38) - if device == "cuda": - self.assertTrue("mixed_mm" in code, f"got code: {code}") @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") From 38e36ded525472cfaf70945209ca49763778d71e Mon Sep 17 00:00:00 2001 From: Facebook Community Bot Date: Mon, 24 Feb 2025 14:36:33 -0800 Subject: [PATCH 114/115] Auto-fix lint violations from Fixit] fbcode//pytorch/ao (#1752) Auto-fix lint violations from Fixit] fbcode//pytorch/ao (#1752) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/1752 Reviewed By: amyreese Differential Revision: D69041228 Co-authored-by: CodemodService Bot --- torchao/quantization/GPTQ.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index cb7c8d0481..b278e22b3b 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -759,7 +759,7 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - logging.warn( + logging.warning( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) padded_in_features = find_multiple(in_features, 1024) @@ -767,7 +767,7 @@ def _create_quantized_state_dict( weight, pad=(0, padded_in_features - in_features) ) else: - logging.warn( + logging.warning( f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + "and that groupsize and inner_k_tiles*16 evenly divide into it" ) @@ -1147,7 +1147,7 @@ def _create_quantized_state_dict( if self.padding_allowed: import torch.nn.functional as F - logging.warn( + logging.warning( f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" ) padded_in_features = find_multiple(in_features, 1024) @@ -1155,7 +1155,7 @@ def _create_quantized_state_dict( weight, pad=(0, padded_in_features - in_features) ) else: - logging.warn( + logging.warning( f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + "and that groupsize and inner_k_tiles*16 evenly divide into it" ) From 98c4e2e06d7f9da57a417a888971820d28eec397 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Tue, 25 Feb 2025 14:46:19 -0500 Subject: [PATCH 115/115] Fix potential out-of-bound access in int8_mm.py (#1751) * fix potential out-of-bound access * remove unused EVEN_K * refactor fix with triton.heuristics * restore EVEN_K as an input * fix typo * fix another typo * ruff reformatted --- torchao/prototype/quantized_training/int8_mm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/quantized_training/int8_mm.py b/torchao/prototype/quantized_training/int8_mm.py index 7de6620d65..faaa6e463e 100644 --- a/torchao/prototype/quantized_training/int8_mm.py +++ b/torchao/prototype/quantized_training/int8_mm.py @@ -54,6 +54,7 @@ @triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"]) +@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0}) @triton.jit def _scaled_int8_mm_kernel( A_ptr, @@ -176,7 +177,6 @@ def scaled_int8_mm_cuda(A: Tensor, B: Tensor, row_scale: Tensor, col_scale: Tens *A.stride(), *B.stride(), *C.stride(), - EVEN_K=K % 2 == 0, COL_SCALE_SCALAR=col_scale.numel() == 1, ) return C