Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial stab at string based config parser #1774

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions test/quantization/test_config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import unittest

import torch

from torchao.quantization.config_parser import ConfigParser
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
MappingType,
UIntXWeightOnlyConfig,
)


class TestConfigParser(unittest.TestCase):
    """End-to-end checks for ``ConfigParser.parse`` string handling."""

    def setUp(self):
        self.parser = ConfigParser()

    def _parse_as(self, config_str, expected_cls):
        """Parse ``config_str``, assert the result's type, and return it."""
        cfg = self.parser.parse(config_str)
        self.assertIsInstance(cfg, expected_cls)
        return cfg

    def test_int4wo_config(self):
        # Int4 weight-only with a couple of group sizes.
        cfg = self._parse_as("int4wo_g32", Int4WeightOnlyConfig)
        self.assertEqual(cfg.group_size, 32)

        cfg = self._parse_as("int4wo_g64", Int4WeightOnlyConfig)
        self.assertEqual(cfg.group_size, 64)

    def test_int8wo_config(self):
        cfg = self._parse_as("int8wo_g128", Int8WeightOnlyConfig)
        self.assertEqual(cfg.group_size, 128)

        # Int8WeightOnlyConfig takes no mapping_type, so a symmetry token
        # must be rejected with a helpful message.
        with self.assertRaises(ValueError) as context:
            self.parser.parse("int8wo_g128_sym")

        message = str(context.exception)
        self.assertIn("Invalid parameters for Int8WeightOnlyConfig", message)
        self.assertIn("mapping_type", message)

    def test_int8dqint4_config(self):
        # Int8 dynamic activation over Int4 weight, with and without symmetry.
        cfg = self._parse_as("int8dqint4_g32", Int8DynamicActivationInt4WeightConfig)
        self.assertEqual(cfg.group_size, 32)

        cfg = self._parse_as(
            "int8dqint4_g32_sym", Int8DynamicActivationInt4WeightConfig
        )
        self.assertEqual(cfg.group_size, 32)
        self.assertEqual(cfg.mapping_type, MappingType.SYMMETRIC)

    def test_int8dqint8_config(self):
        # Bare base token, no parameter tokens at all.
        self._parse_as("int8dqint8", Int8DynamicActivationInt8WeightConfig)

    def test_int4dqint4_config(self):
        cfg = self._parse_as("int4dqint4_sym", Int4DynamicActivationInt4WeightConfig)
        self.assertEqual(cfg.mapping_type, MappingType.SYMMETRIC)

    def test_float8wo_config(self):
        # Both float8 weight dtypes must round-trip through the dtype token.
        cfg = self._parse_as("float8wo_e4m3", Float8WeightOnlyConfig)
        self.assertEqual(cfg.weight_dtype, torch.float8_e4m3fn)

        cfg = self._parse_as("float8wo_e5m2", Float8WeightOnlyConfig)
        self.assertEqual(cfg.weight_dtype, torch.float8_e5m2)

    def test_float8dqfloat8_config(self):
        cfg = self._parse_as(
            "float8dqfloat8_e4m3", Float8DynamicActivationFloat8WeightConfig
        )
        self.assertEqual(cfg.activation_dtype, torch.float8_e4m3fn)
        self.assertEqual(cfg.weight_dtype, torch.float8_e4m3fn)

    def test_uintxwo_config(self):
        self._parse_as("uintxwo_uint4_g32", UIntXWeightOnlyConfig)
        self._parse_as("uintxwo_uint8_g64", UIntXWeightOnlyConfig)

    # TODO: add FPX coverage once FPXWeightOnlyConfig parsing is supported
    # (e.g. "fpx_e4m3" -> FPXWeightOnlyConfig(ebits=4, mbits=3)).

    def test_invalid_config_string(self):
        # Empty string, unknown base token, and an unparseable parameter
        # token must each raise ValueError.
        for bad in ("", "unknown_config", "int4wo_invalid_token"):
            with self.assertRaises(ValueError):
                self.parser.parse(bad)

    def test_complex_configurations(self):
        cfg = self._parse_as("int4wo_g32", Int4WeightOnlyConfig)
        self.assertEqual(cfg.group_size, 32)

        cfg = self._parse_as(
            "int8dqint4_g32_asym", Int8DynamicActivationInt4WeightConfig
        )
        self.assertEqual(cfg.group_size, 32)
        self.assertEqual(cfg.mapping_type, MappingType.ASYMMETRIC)


# Allow running this test file directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()
226 changes: 226 additions & 0 deletions torchao/quantization/config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import dataclasses
import re
from typing import Any, Dict, List, Protocol, Tuple, Type

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
MappingType,
# Float8StaticActivationFloat8WeightConfig,
UIntXWeightOnlyConfig,
)

# Create a type alias for AOBaseConfig classes
ConfigType = Type[AOBaseConfig]


# Define a Protocol for parameter processors
class ParameterProcessor(Protocol):
    """Structural interface for token processors.

    A parameter processor takes the regex match for one config-string token
    plus the config class being built, and produces the keyword argument
    that token contributes to the config's constructor.
    """

    def __call__(self, match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
        """Turn ``match`` into a ``(parameter_name, parameter_value)`` pair.

        Args:
            match: Regex match object for a single token, with captured groups.
            quant_config: The config *class* (not an instance) being built.

        Returns:
            The ``(name, value)`` keyword pair for the config constructor.

        Note:
            ``quant_config`` is a class, so any dispatch on it must use
            ``issubclass`` rather than ``isinstance``.
        """
        ...


def process_bits(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    """Convert an ``<N>bit`` token (e.g. ``"4bit"``) into the ``bits`` kwarg.

    ``quant_config`` is the config class (unused here); the annotation is
    aligned with the ``ParameterProcessor`` protocol, which passes a class.
    """
    return "bits", int(match.group(1))


def process_group_size(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    """Convert a ``g<N>`` token (e.g. ``"g32"``) into the ``group_size`` kwarg.

    ``quant_config`` is the config class (unused here); the annotation is
    aligned with the ``ParameterProcessor`` protocol, which passes a class.
    """
    return "group_size", int(match.group(1))


def process_activation_bits(
    match: re.Match, quant_config: ConfigType
) -> Tuple[str, Any]:
    """Convert an ``act<N>`` token into the ``activation_bits`` kwarg.

    Currently unused: the corresponding pattern is commented out in
    ``ConfigParser.param_patterns``. Annotation aligned with the
    ``ParameterProcessor`` protocol, which passes a config class.
    """
    return "activation_bits", int(match.group(1))


def process_weight_bits(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    """Convert a ``w<N>`` token into the ``weight_bits`` kwarg.

    Currently unused: the corresponding pattern is commented out in
    ``ConfigParser.param_patterns``. Annotation aligned with the
    ``ParameterProcessor`` protocol, which passes a config class.
    """
    return "weight_bits", int(match.group(1))


def process_symmetry(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    """Convert a ``sym``/``asym`` token into the ``mapping_type`` kwarg.

    ``quant_config`` is the config class (unused here); the annotation is
    aligned with the ``ParameterProcessor`` protocol, which passes a class.
    """
    mapping_type = (
        MappingType.SYMMETRIC if match.group(1) == "sym" else MappingType.ASYMMETRIC
    )
    return "mapping_type", mapping_type


def process_dtype(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    """Map a dtype token (``int4``, ``uint8``, ``e4m3``, ...) to a torch dtype kwarg.

    The float8 configs name the parameter ``weight_dtype`` while the other
    configs name it ``dtype``, so the keyword is chosen per config class.
    ``quant_config`` must be the config class itself (``issubclass`` below).
    """
    # NOTE(review): sub-byte dtypes such as torch.int4/torch.uint4 only exist
    # on sufficiently recent torch versions — confirm the minimum supported
    # torch pin covers them.
    dtype_map = {
        "int4": torch.int4,
        "int8": torch.int8,
        "uint4": torch.uint4,
        "uint8": torch.uint8,
        "e4m3": torch.float8_e4m3fn,
        "e5m2": torch.float8_e5m2,
    }
    # The float8's have different key names :(
    key = (
        "weight_dtype"
        if issubclass(
            quant_config,
            (
                Float8WeightOnlyConfig,
                Float8DynamicActivationFloat8WeightConfig,
            ),
        )
        else "dtype"
    )
    return key, dtype_map[match.group(1)]


def process_per_row(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    """Convert a ``per_row`` token into ``per_row=True``.

    ``quant_config`` is the config class (unused here); the annotation is
    aligned with the ``ParameterProcessor`` protocol, which passes a class.
    """
    return "per_row", True


class ConfigParser:
    """Parser for string-based configuration patterns.

    A config string is a sequence of underscore-separated tokens, e.g.
    ``"int8dqint4_g32_sym"``. The first token selects the base config class
    via ``type_mapping``; every following token is matched (with
    ``fullmatch``) against ``param_patterns`` and converted into one keyword
    argument for that class's constructor.
    """

    # Parameter patterns with their processing functions. Patterns are tried
    # in insertion order; a token must match a pattern exactly (fullmatch).
    param_patterns: Dict[re.Pattern, ParameterProcessor] = {
        re.compile(r"(\d+)bit"): process_bits,
        re.compile(r"g(\d+)"): process_group_size,
        # re.compile(r"act(\d+)"): process_activation_bits,
        # re.compile(r"w(\d+)"): process_weight_bits,
        re.compile(r"(sym|asym)"): process_symmetry,
        re.compile(r"(int4|int8|uint4|uint8|e4m3|e5m2)"): process_dtype,
        # NOTE(review): tokens are produced by splitting the config string on
        # "_", so a "per_row" token can never reach this pattern intact — it
        # arrives as two tokens, "per" and "row". This entry is currently
        # unreachable; a different spelling (e.g. "perrow") or smarter
        # tokenization is needed. TODO confirm intended syntax.
        re.compile(r"(per_row)"): process_per_row,
    }

    # Map from string prefix to the config class it selects.
    type_mapping: Dict[str, ConfigType] = {
        "int4wo": Int4WeightOnlyConfig,
        "int8wo": Int8WeightOnlyConfig,
        "int8dqint4": Int8DynamicActivationInt4WeightConfig,
        "int8dqint8": Int8DynamicActivationInt8WeightConfig,
        "int4dqint4": Int4DynamicActivationInt4WeightConfig,
        "float8wo": Float8WeightOnlyConfig,
        "float8dqfloat8": Float8DynamicActivationFloat8WeightConfig,
        # "float8staticfloat8": Float8StaticActivationFloat8WeightConfig,
        "uintxwo": UIntXWeightOnlyConfig,
        # "fpx": FPXWeightOnlyConfig,
    }

    def parse(self, config_str: str) -> AOBaseConfig:
        """
        Parse a configuration string into an AO quantization configuration object.

        This is the main entrypoint for converting string-based configuration
        into actual config objects. The expected format is
        ``"<base>_<token>_<token>..."`` where ``<base>`` identifies the base
        quantization type and each subsequent token encodes one parameter
        value (e.g. ``g32`` for ``group_size=32``, ``sym`` for a symmetric
        mapping type).

        Examples:
            config_parser.parse("int8dqint8")
            config_parser.parse("int8dqint4_g32")

        Args:
            config_str: String representation of the quantization configuration

        Returns:
            AOBaseConfig: Instantiated quantization configuration object

        Raises:
            ValueError: If the config string is empty or invalid
        """
        # Check the raw string: "".split("_") returns [""], so a truthiness
        # check on the token list alone would never fire for empty input.
        if not config_str:
            raise ValueError("Empty config string")

        tokens = config_str.split("_")

        # The first token is the base quantization type
        quant_config = self._get_config(tokens[0])

        # We know the base quant type, now we convert each token to its parameter
        params = self._extract_params(quant_config, tokens[1:])

        return self._instantiate_config(quant_config, params)

    def _get_config(self, first_token: str) -> ConfigType:
        """Resolve the leading token to a config class, or raise ValueError."""
        try:
            return self.type_mapping[first_token]
        except KeyError:
            # List the available base configurations in the error message;
            # suppress the KeyError context since it adds no information.
            available_configs = list(self.type_mapping.keys())
            raise ValueError(
                f"Unknown quantization type in string: {first_token} \n Available base configurations: {available_configs}"
            ) from None

    def _instantiate_config(
        self, quant_config: ConfigType, params: Dict[str, Any]
    ) -> AOBaseConfig:
        """Instantiate the config class, converting a constructor TypeError
        into a ValueError that lists the invalid and accepted parameters."""
        try:
            return quant_config(**params)
        except TypeError as e:
            # Get proper field information for error message. (Dataclass
            # fields can never be named "self", so no filtering is needed.)
            valid_fields = {field.name for field in dataclasses.fields(quant_config)}
            invalid_params = {k: v for k, v in params.items() if k not in valid_fields}

            field_info = [field.name for field in dataclasses.fields(quant_config)]

            raise ValueError(
                f"Invalid parameters for {quant_config.__name__}: {list(invalid_params.keys())}.\n"
                f"Available parameters for {quant_config.__name__}: {field_info}"
            ) from e

    def _extract_params(
        self, quant_config: ConfigType, param_tokens: List[str]
    ) -> Dict[str, Any]:
        """Convert each parameter token into a constructor keyword argument.

        Raises:
            ValueError: If a token matches none of ``param_patterns``.
        """
        params = {}

        for token in param_tokens:
            if not token:
                continue

            matched = False
            # Try to match against parameter patterns
            # We could specify an ordering but for now we just try all
            for pattern, processor in self.param_patterns.items():
                match = pattern.fullmatch(token)
                if match:
                    param_name, value = processor(match, quant_config)
                    params[param_name] = value
                    matched = True
                    break

            if not matched:
                field_info = [
                    (field.name, field.type)
                    for field in dataclasses.fields(quant_config)
                ]
                raise ValueError(
                    f"Unrecognized parameter token: {token} in {param_tokens}\nAvailable parameters for {quant_config.__name__}: {field_info}"
                )

        return params
Loading