Skip to content

Commit

Permalink
Initial stab at string based config parser
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Feb 25, 2025
1 parent 38e36de commit c9f0b11
Show file tree
Hide file tree
Showing 2 changed files with 358 additions and 0 deletions.
132 changes: 132 additions & 0 deletions test/quantization/test_config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import unittest

import torch

from torchao.quantization.config_parser import ConfigParser
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
MappingType,
UIntXWeightOnlyConfig,
)


class TestConfigParser(unittest.TestCase):
def setUp(self):
self.parser = ConfigParser()

def test_int4wo_config(self):
# Basic Int4WeightOnlyConfig
config = self.parser.parse("int4wo_g32")
self.assertIsInstance(config, Int4WeightOnlyConfig)
self.assertEqual(config.group_size, 32)

# With symmetry specified
config = self.parser.parse("int4wo_g64")
self.assertIsInstance(config, Int4WeightOnlyConfig)
self.assertEqual(config.group_size, 64)

def test_int8wo_config(self):
# Basic Int8WeightOnlyConfig
config = self.parser.parse("int8wo_g128")
self.assertIsInstance(config, Int8WeightOnlyConfig)
self.assertEqual(config.group_size, 128)

# Verify that symmetry parameter raises error since not supported
with self.assertRaises(ValueError) as context:
self.parser.parse("int8wo_g128_sym")

self.assertIn(
"Invalid parameters for Int8WeightOnlyConfig", str(context.exception)
)
self.assertIn("mapping_type", str(context.exception))

def test_int8dqint4_config(self):
# Int8 dynamic activation with Int4 weight
config = self.parser.parse("int8dqint4_g32")
self.assertIsInstance(config, Int8DynamicActivationInt4WeightConfig)
self.assertEqual(config.group_size, 32)

# With symmetry
config = self.parser.parse("int8dqint4_g32_sym")
self.assertIsInstance(config, Int8DynamicActivationInt4WeightConfig)
self.assertEqual(config.group_size, 32)
self.assertEqual(config.mapping_type, MappingType.SYMMETRIC)

def test_int8dqint8_config(self):
# Int8 dynamic activation with Int8 weight
config = self.parser.parse("int8dqint8")
self.assertIsInstance(config, Int8DynamicActivationInt8WeightConfig)

def test_int4dqint4_config(self):
# Int4 dynamic activation with Int4 weight
config = self.parser.parse("int4dqint4_sym")
self.assertIsInstance(config, Int4DynamicActivationInt4WeightConfig)
self.assertEqual(config.mapping_type, MappingType.SYMMETRIC)

def test_float8wo_config(self):
# Basic Float8WeightOnlyConfig with e4m3 dtype
config = self.parser.parse("float8wo_e4m3")
self.assertIsInstance(config, Float8WeightOnlyConfig)
self.assertEqual(config.weight_dtype, torch.float8_e4m3fn)

# With e5m2 dtype
config = self.parser.parse("float8wo_e5m2")
self.assertIsInstance(config, Float8WeightOnlyConfig)
self.assertEqual(config.weight_dtype, torch.float8_e5m2)

def test_float8dqfloat8_config(self):
# Float8 dynamic activation with Float8 weight
config = self.parser.parse("float8dqfloat8_e4m3")
self.assertIsInstance(config, Float8DynamicActivationFloat8WeightConfig)
self.assertEqual(config.activation_dtype, torch.float8_e4m3fn)
self.assertEqual(config.weight_dtype, torch.float8_e4m3fn)

def test_uintxwo_config(self):
# UIntX config with uint4
config = self.parser.parse("uintxwo_uint4_g32")
self.assertIsInstance(config, UIntXWeightOnlyConfig)

# With uint8
config = self.parser.parse("uintxwo_uint8_g64")
self.assertIsInstance(config, UIntXWeightOnlyConfig)

# def test_fpx_config(self):
# # FPX config
# config = self.parser.parse("fpx_e4m3")
# self.assertIsInstance(config, FPXWeightOnlyConfig)
# self.assertEqual(config.ebits, 4)
# self.assertEqual(config.mbits, 3)

def test_invalid_config_string(self):
# Test empty string
with self.assertRaises(ValueError):
self.parser.parse("")

# Test unknown base config
with self.assertRaises(ValueError):
self.parser.parse("unknown_config")

# Test invalid parameter token
with self.assertRaises(ValueError):
self.parser.parse("int4wo_invalid_token")

def test_complex_configurations(self):
# Adjust tests for complex configurations to match actual parameter names
config = self.parser.parse("int4wo_g32")
self.assertIsInstance(config, Int4WeightOnlyConfig)
self.assertEqual(config.group_size, 32)

config = self.parser.parse("int8dqint4_g32_asym")
self.assertIsInstance(config, Int8DynamicActivationInt4WeightConfig)
self.assertEqual(config.group_size, 32)
self.assertEqual(config.mapping_type, MappingType.ASYMMETRIC)


if __name__ == "__main__":
unittest.main()
226 changes: 226 additions & 0 deletions torchao/quantization/config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import dataclasses
import re
from typing import Any, Dict, List, Protocol, Tuple, Type

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
MappingType,
# Float8StaticActivationFloat8WeightConfig,
UIntXWeightOnlyConfig,
)

# Create a type alias for AOBaseConfig classes
ConfigType = Type[AOBaseConfig]


# Define a Protocol for parameter processors
class ParameterProcessor(Protocol):
"""Protocol defining the interface for parameter processors"""

def __call__(self, match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
"""
Process a regex match into a parameter name and value
Args:
match: The regex match object containing captured groups
quant_config: The quantization config class being instantiated
Returns:
Tuple of (parameter_name, parameter_value)
Note:
If you need special handling based on the quant_config type,
be sure to use issubclass instead of isinstance.
"""
...


def process_bits(match: re.Match, quant_config: AOBaseConfig) -> Tuple[str, Any]:
return "bits", int(match.group(1))


def process_group_size(match: re.Match, quant_config: AOBaseConfig) -> Tuple[str, Any]:
return "group_size", int(match.group(1))


def process_activation_bits(
match: re.Match, quant_config: AOBaseConfig
) -> Tuple[str, Any]:
return "activation_bits", int(match.group(1))


def process_weight_bits(match: re.Match, quant_config: AOBaseConfig) -> Tuple[str, Any]:
return "weight_bits", int(match.group(1))


def process_symmetry(match: re.Match, quant_config: AOBaseConfig) -> Tuple[str, Any]:
mapping_type = (
MappingType.SYMMETRIC if match.group(1) == "sym" else MappingType.ASYMMETRIC
)
return "mapping_type", mapping_type


def process_dtype(match: re.Match, quant_config: AOBaseConfig) -> Tuple[str, Any]:
dtype_map = {
"int4": torch.int4,
"int8": torch.int8,
"uint4": torch.uint4,
"uint8": torch.uint8,
"e4m3": torch.float8_e4m3fn,
"e5m2": torch.float8_e5m2,
}
# The float8's have different key names :(
key = (
"weight_dtype"
if issubclass(
quant_config,
(
Float8WeightOnlyConfig,
Float8DynamicActivationFloat8WeightConfig,
),
)
else "dtype"
)
return key, dtype_map[match.group(1)]


def process_per_row(match: re.Match, quant_config: AOBaseConfig) -> Tuple[str, Any]:
return "per_row", True


class ConfigParser:
"""Parser for string-based configuration patterns"""

# Parameter patterns with their processing functions
param_patterns: Dict[re.Pattern, ParameterProcessor] = {
re.compile(r"(\d+)bit"): process_bits,
re.compile(r"g(\d+)"): process_group_size,
# re.compile(r"act(\d+)"): process_activation_bits,
# re.compile(r"w(\d+)"): process_weight_bits,
re.compile(r"(sym|asym)"): process_symmetry,
re.compile(r"(int4|int8|uint4|uint8|e4m3|e5m2)"): process_dtype,
re.compile(r"(per_row)"): process_per_row,
}

# Map from string prefix to QuantType
type_mapping = {
"int4wo": Int4WeightOnlyConfig,
"int8wo": Int8WeightOnlyConfig,
"int8dqint4": Int8DynamicActivationInt4WeightConfig,
"int8dqint8": Int8DynamicActivationInt8WeightConfig,
"int4dqint4": Int4DynamicActivationInt4WeightConfig,
"float8wo": Float8WeightOnlyConfig,
"float8dqfloat8": Float8DynamicActivationFloat8WeightConfig,
# "float8staticfloat8": Float8StaticActivationFloat8WeightConfig,
"uintxwo": UIntXWeightOnlyConfig,
# "fpx": FPXWeightOnlyConfig,
}

def parse(self, config_str: str) -> AOBaseConfig:
"""
Parse a configuration string into an AO quantization configuration object.
This is the main entrypoint for converting string-based configuration into actual config objects.
The expected format is "base_param1-value1_param2-value2" where "base" identifies the base
quantization type and subsequent tokens specify parameter values.
Examples:
config_parser.parse("int8dqint8")
config_parser.parse("int8dqint4_g32")
Args:
config_str: String representation of the quantization configuration
Returns:
AOBaseConfig: Instantiated quantization configuration object
Raises:
ValueError: If the config string is empty or invalid
"""
tokens = config_str.split("_")

if not tokens:
raise ValueError("Empty config string")

# The first token is the base quantization type
quant_config = self._get_config(tokens[0])

# We know the base quant type, now we convert each token to its parameter
params = self._extract_params(quant_config, tokens[1:])

return self._instantiate_config(quant_config, params)

def _get_config(self, first_token: str) -> AOBaseConfig:
"""Get the quantization config from a string"""
try:
quant_config = self.type_mapping[first_token]
except KeyError:
# Print available base configurations before raising error
available_configs = list(self.type_mapping.keys())
raise ValueError(
f"Unknown quantization type in string: {first_token} \n Available base configurations: {available_configs}"
)
return quant_config

def _instantiate_config(
self, quant_config: AOBaseConfig, params: Dict[str, Any]
) -> AOBaseConfig:
"""Sprinkle some extra logic for helping w/ instantiation failures"""
try:
return quant_config(**params)
except TypeError as e:
# Get proper field information for error message
valid_fields = {
field.name
for field in dataclasses.fields(quant_config)
if field.name != "self"
}
invalid_params = {k: v for k, v in params.items() if k not in valid_fields}

field_info = [field.name for field in dataclasses.fields(quant_config)]

raise ValueError(
f"Invalid parameters for {quant_config.__name__}: {list(invalid_params.keys())}.\n"
f"Available parameters for {quant_config.__name__}: {field_info}"
) from e

def _extract_params(
self, quant_config: AOBaseConfig, param_tokens: List[str]
) -> Dict[str, Any]:
"""Extract parameters from tokens"""
params = {}

for token in param_tokens:
if not token:
continue

matched = False
# Try to match against parameter patterns
# We could specify an ordering but for now we just try all
for pattern, processor in self.param_patterns.items():
match = pattern.fullmatch(token)
if match:
param_name, value = processor(match, quant_config)
params[param_name] = value
matched = True
break

if not matched:
field_info = [
(field.name, field.type)
for field in dataclasses.fields(quant_config)
]
raise ValueError(
f"Unrecognized parameter token: {token} in {param_tokens}\nAvailable parameters for {quant_config.__name__}: {field_info}"
)

return params

0 comments on commit c9f0b11

Please sign in to comment.