Update
[ghstack-poisoned]
vkuzo committed Feb 26, 2025
2 parents 77c8628 + 64647f1 commit 4bed185
Showing 6 changed files with 66 additions and 81 deletions.
2 changes: 1 addition & 1 deletion benchmarks/float8/float8_roofline.py
@@ -60,7 +60,7 @@
from torchao.float8 import (
    convert_to_float8_training,
)
-from torchao.float8.roofline_utils import (
+from torchao.testing.float8.roofline_utils import (
    get_float8_mem_sympy,
    get_gemm_time_sympy,
)
2 changes: 1 addition & 1 deletion test/float8/test_base.py
@@ -37,7 +37,7 @@
from torchao.float8.float8_linear_utils import (
    convert_to_float8_training,
)
-from torchao.float8.float8_python_api import addmm_float8_unwrapped
+from torchao.float8.float8_ops import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import (
    get_maybe_axiswise_dim,
    hp_tensor_to_float8_dynamic,
@@ -14,13 +14,13 @@

#if defined(TORCHAO_ENABLE_KLEIDI)
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>
+using namespace torchao::kernels::cpu::aarch64::kleidi::
+    kai_matmul_clamp_f32_qai8dxp_qsi4c32p;
#endif // TORCHAO_ENABLE_KLEIDI

const float kTol = 1.0e-5;

using namespace torchao::ops::linear_8bit_act_xbit_weight;
-using namespace torchao::kernels::cpu::aarch64::kleidi::
-    kai_matmul_clamp_f32_qai8dxp_qsi4c32p;

template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
UKernelConfig get_ukernel_config() {
64 changes: 62 additions & 2 deletions torchao/float8/float8_ops.py
@@ -3,12 +3,11 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple

import torch
from torch.utils._pytree import tree_map

-from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_tensor import Float8Tensor, choose_scaled_mm_config
from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul

@@ -18,6 +17,67 @@
FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


# [Note] Usage of scales
# The meaning of scale in this library can be found in the definition of Float8Tensor.
# cuBLAS defines a scale to always mean a multiplicative factor for the respective matrices.
# For a and b going from fp8 -> fp32 we multiply by the inverse of the scale;
# for the output going from fp32 -> fp8 we multiply by the scale.
def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """
    This is the unwrapped version of addmm_float8, which does not take in Float8Tensors
    as inputs. This is used to standardize the logic between subclassed and non-subclassed
    versions of the linear module.
    """
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()

    post_inverse_scale = None
    if (
        a_scale.shape == (a_data.shape[0], 1)
        and b_scale.shape == (1, b_data.shape[1])
        and not use_fast_accum
    ):
        # The rowwise CUTLASS-based kernel is so slow without fast-accum that
        # we'd rather use the tensorwise cuBLAS-based kernel and do the scaling
        # manually afterwards (hoping Inductor will be able to fuse it).
        post_inverse_scale = a_inverse_scale * b_inverse_scale
        a_inverse_scale = a_inverse_scale.new_ones(())
        b_inverse_scale = b_inverse_scale.new_ones(())

    post_bias = None
    if output_dtype == torch.float32:
        # Bias is not supported by _scaled_mm when output is fp32
        post_bias = bias
        bias = None

    output = torch._scaled_mm(
        a_data,
        b_data,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        bias=bias,
        scale_result=output_scale,
        out_dtype=output_dtype,
        use_fast_accum=use_fast_accum,
    )

    if post_inverse_scale is not None:
        output *= post_inverse_scale
    if post_bias is not None:
        output += post_bias

    return output


def _assert_tensorwise_scale(aten_op, scale):
    assert (
        # TODO(future PR): figure out why tensorwise scaling can have
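For context on the relocated helper, here is a minimal usage sketch that is not part of this commit. It assumes a CUDA device with float8 support (e.g. an H100); the names a_hp, b_hp and the shapes M, K, N are made up for illustration.

import torch

from torchao.float8.float8_ops import addmm_float8_unwrapped

M, K, N = 16, 32, 16
a_hp = torch.randn(M, K, device="cuda")
b_hp = torch.randn(K, N, device="cuda")

# Per the scale note above: scale maps high precision -> fp8 (data = hp * scale),
# so the kernel descales by multiplying with the inverse of each scale.
a_scale = torch.finfo(torch.float8_e4m3fn).max / a_hp.abs().max()
b_scale = torch.finfo(torch.float8_e4m3fn).max / b_hp.abs().max()
a_data = (a_hp * a_scale).to(torch.float8_e4m3fn)
# torch._scaled_mm expects the second operand in column-major layout
b_data = (b_hp * b_scale).to(torch.float8_e4m3fn).t().contiguous().t()

out = addmm_float8_unwrapped(
    a_data, a_scale, b_data, b_scale, output_dtype=torch.bfloat16
)
# should approximate the high precision matmul up to fp8 quantization error
print((out.float() - a_hp @ b_hp).abs().max())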
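The rowwise branch above defers descaling until after the matmul. The algebraic identity it relies on can be checked with plain float32 tensors on CPU; the shapes below are illustrative.

import torch

M, K, N = 4, 8, 3
a = torch.randn(M, K)
b = torch.randn(K, N)
a_inverse_scale = torch.rand(M, 1) + 0.5  # per-row inverse scales for a
b_inverse_scale = torch.rand(1, N) + 0.5  # per-column inverse scales for b

# Descaling the inputs before the matmul...
descale_before = (a * a_inverse_scale) @ (b * b_inverse_scale)
# ...equals running the matmul on raw data and applying the outer product of the
# inverse scales afterwards, which is what the branch above defers to.
descale_after = (a @ b) * (a_inverse_scale * b_inverse_scale)

print(torch.allclose(descale_before, descale_after, atol=1e-6))  # True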
75 changes: 0 additions & 75 deletions torchao/float8/float8_python_api.py

This file was deleted.

File renamed without changes.
