
Commit 436d3aa
Granularity validation
jainapurva committed Nov 25, 2024
1 parent b5945d3 commit 436d3aa
Showing 2 changed files with 33 additions and 7 deletions.
28 changes: 22 additions & 6 deletions torchao/quantization/quant_api.py
@@ -52,6 +52,8 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     is_MI300,
+    is_sm_89,
+    is_sm_90,
 )

 from .autoquant import AutoQuantizableLinearWeight, autoquant
@@ -83,7 +85,6 @@
 from .utils import _get_per_token_block_size

 logger = logging.getLogger(__name__)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

 __all__ = [
     "swap_conv2d_1x1_to_linear",
@@ -829,10 +830,11 @@ def _normalize_granularity(
         Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
     ],
 ) -> Tuple[_fp8_granularities, _fp8_granularities]:
+    processed_granularity = None
     if granularity is None:
-        return (PerTensor(), PerTensor())
+        processed_granularity = (PerTensor(), PerTensor())
     elif isinstance(granularity, (PerTensor, PerRow)):
-        return (granularity, granularity)
+        processed_granularity = (granularity, granularity)
     elif isinstance(granularity, tuple) and len(granularity) == 2:
         if not (
             isinstance(granularity[0], (PerTensor, PerRow))
@@ -845,11 +847,25 @@
             raise ValueError(
                 f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
             )
-        return granularity
+        processed_granularity = granularity
     else:
         raise ValueError(
             f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
         )
+    # Validate granularity with supported Hardware
+    for _granularity in processed_granularity:
+        if isinstance(_granularity, PerTensor):
+            assert (
+                is_sm_89() or is_MI300()
+            ), "PerTensor quantization only works for CUDA>=8.9 and MI300+"
+        elif isinstance(_granularity, PerRow):
+            assert (
+                is_sm_90() or is_MI300()
+            ), "PerRow quantization only works for CUDA>=9.0 and MI300+"
+        else:
+            raise ValueError(f"Invalid granularity type: {_granularity}")
+
+    return processed_granularity


 def _input_activation_quant_func_fp8(
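After this hunk, _normalize_granularity both canonicalizes the granularity argument and gates it on hardware. A minimal sketch of the resulting contract, assuming the PerTensor/PerRow dataclasses (dataclass equality assumed) and an sm90-class GPU or MI300, since on older hardware the new asserts fire first; the import paths are assumptions, not shown in this diff:

```python
# Sketch of _normalize_granularity's behavior after this commit (assumed imports).
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import _normalize_granularity

# Every accepted spec collapses to an (activation, weight) pair.
assert _normalize_granularity(None) == (PerTensor(), PerTensor())            # default
assert _normalize_granularity(PerRow()) == (PerRow(), PerRow())              # broadcast
assert _normalize_granularity((PerRow(), PerRow())) == (PerRow(), PerRow())  # pass-through

# Mixed activation/weight granularities are still rejected.
try:
    _normalize_granularity((PerTensor(), PerRow()))
except ValueError as err:
    print(err)
```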
@@ -942,7 +958,7 @@ def float8_dynamic_activation_float8_weight(
"""
assert (
is_cuda_8_9 or is_MI300()
is_sm_89() or is_MI300()
), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
if mm_config is None:
mm_config = Float8MMConfig(use_fast_accum=True)
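For context, a hypothetical caller of the API this hunk touches; quantize_, the granularity keyword, and the import paths are assumed from torchao's public surface rather than shown in this diff:

```python
# Hypothetical usage sketch: per-row fp8 dynamic quantization of a linear layer.
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
from torchao.quantization.granularity import PerRow

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# With this commit, PerRow is validated inside _normalize_granularity: it passes
# on sm90 (e.g. H100) or MI300 hardware, and raises AssertionError elsewhere.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```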
Expand Down Expand Up @@ -999,7 +1015,7 @@ def float8_static_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
     assert (
-        is_cuda_8_9 or is_MI300()
+        is_sm_89() or is_MI300()
     ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     if mm_config is None:
         mm_config = Float8MMConfig(use_fast_accum=True)
12 changes: 11 additions & 1 deletion torchao/utils.py
@@ -32,6 +32,8 @@
"TORCH_VERSION_AFTER_2_4",
"TORCH_VERSION_AFTER_2_5",
"is_MI300",
"is_sm_89",
"is_sm_90",
]


@@ -597,14 +599,22 @@ def is_MI300():
     return False


-def is_cuda_8_9():
+def is_sm_89():
     return (
         torch.cuda.is_available()
         and torch.version.cuda
         and torch.cuda.get_device_capability() >= (8, 9)
     )


+def is_sm_90():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (9, 0)
+    )
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
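Note that despite the names, both helpers are ">=" capability checks — is_sm_89() is also true on sm90 hardware, which is why PerTensor stays usable on H100 — and the torch.version.cuda guard keeps them False on ROCm builds, where is_MI300() takes over. A small sketch of gating on them; fp8_supported_granularities is an illustrative helper, not part of this commit:

```python
# Illustrative helper built on the functions above (the name is hypothetical).
from torchao.utils import is_MI300, is_sm_89, is_sm_90

def fp8_supported_granularities() -> list[str]:
    """Report which fp8 granularities this machine can run, per the new checks."""
    supported = []
    if is_sm_89() or is_MI300():  # Ada (sm89) and newer, or MI300
        supported.append("PerTensor")
    if is_sm_90() or is_MI300():  # Hopper (sm90) and newer, or MI300
        supported.append("PerRow")
    return supported

print(fp8_supported_granularities())
```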
