-
Notifications
You must be signed in to change notification settings - Fork 223
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial stab at string based config parser
- Loading branch information
Showing
2 changed files
with
358 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import unittest | ||
|
||
import torch | ||
|
||
from torchao.quantization.config_parser import ConfigParser | ||
from torchao.quantization.quant_api import ( | ||
Float8DynamicActivationFloat8WeightConfig, | ||
Float8WeightOnlyConfig, | ||
Int4DynamicActivationInt4WeightConfig, | ||
Int4WeightOnlyConfig, | ||
Int8DynamicActivationInt4WeightConfig, | ||
Int8DynamicActivationInt8WeightConfig, | ||
Int8WeightOnlyConfig, | ||
MappingType, | ||
UIntXWeightOnlyConfig, | ||
) | ||
|
||
|
||
class TestConfigParser(unittest.TestCase):
    """Checks that ConfigParser.parse maps config strings to config objects."""

    def setUp(self):
        self.parser = ConfigParser()

    def _parse_as(self, config_str, expected_cls):
        """Parse ``config_str``, assert the result type, and return the config."""
        parsed = self.parser.parse(config_str)
        self.assertIsInstance(parsed, expected_cls)
        return parsed

    def test_int4wo_config(self):
        # Group size 32
        config = self._parse_as("int4wo_g32", Int4WeightOnlyConfig)
        self.assertEqual(config.group_size, 32)

        # Group size 64
        config = self._parse_as("int4wo_g64", Int4WeightOnlyConfig)
        self.assertEqual(config.group_size, 64)

    def test_int8wo_config(self):
        config = self._parse_as("int8wo_g128", Int8WeightOnlyConfig)
        self.assertEqual(config.group_size, 128)

        # A symmetry token must be rejected: the resulting mapping_type
        # keyword is expected to be reported as an invalid parameter.
        with self.assertRaises(ValueError) as context:
            self.parser.parse("int8wo_g128_sym")

        message = str(context.exception)
        self.assertIn("Invalid parameters for Int8WeightOnlyConfig", message)
        self.assertIn("mapping_type", message)

    def test_int8dqint4_config(self):
        # Group size only
        config = self._parse_as("int8dqint4_g32", Int8DynamicActivationInt4WeightConfig)
        self.assertEqual(config.group_size, 32)

        # Group size plus symmetric mapping
        config = self._parse_as(
            "int8dqint4_g32_sym", Int8DynamicActivationInt4WeightConfig
        )
        self.assertEqual(config.group_size, 32)
        self.assertEqual(config.mapping_type, MappingType.SYMMETRIC)

    def test_int8dqint8_config(self):
        # Base name alone, no parameter tokens
        self._parse_as("int8dqint8", Int8DynamicActivationInt8WeightConfig)

    def test_int4dqint4_config(self):
        config = self._parse_as("int4dqint4_sym", Int4DynamicActivationInt4WeightConfig)
        self.assertEqual(config.mapping_type, MappingType.SYMMETRIC)

    def test_float8wo_config(self):
        # e4m3 weight dtype
        config = self._parse_as("float8wo_e4m3", Float8WeightOnlyConfig)
        self.assertEqual(config.weight_dtype, torch.float8_e4m3fn)

        # e5m2 weight dtype
        config = self._parse_as("float8wo_e5m2", Float8WeightOnlyConfig)
        self.assertEqual(config.weight_dtype, torch.float8_e5m2)

    def test_float8dqfloat8_config(self):
        config = self._parse_as(
            "float8dqfloat8_e4m3", Float8DynamicActivationFloat8WeightConfig
        )
        self.assertEqual(config.activation_dtype, torch.float8_e4m3fn)
        self.assertEqual(config.weight_dtype, torch.float8_e4m3fn)

    def test_uintxwo_config(self):
        # uint4 with group size 32
        self._parse_as("uintxwo_uint4_g32", UIntXWeightOnlyConfig)

        # uint8 with group size 64
        self._parse_as("uintxwo_uint8_g64", UIntXWeightOnlyConfig)

    # Disabled until an "fpx" base configuration is registered:
    # def test_fpx_config(self):
    #     config = self.parser.parse("fpx_e4m3")
    #     self.assertIsInstance(config, FPXWeightOnlyConfig)
    #     self.assertEqual(config.ebits, 4)
    #     self.assertEqual(config.mbits, 3)

    def test_invalid_config_string(self):
        # Empty string, unknown base name, unknown parameter token
        for bad_string in ("", "unknown_config", "int4wo_invalid_token"):
            with self.subTest(config_str=bad_string):
                with self.assertRaises(ValueError):
                    self.parser.parse(bad_string)

    def test_complex_configurations(self):
        config = self._parse_as("int4wo_g32", Int4WeightOnlyConfig)
        self.assertEqual(config.group_size, 32)

        # Group size combined with an explicit asymmetric mapping
        config = self._parse_as(
            "int8dqint4_g32_asym", Int8DynamicActivationInt4WeightConfig
        )
        self.assertEqual(config.group_size, 32)
        self.assertEqual(config.mapping_type, MappingType.ASYMMETRIC)
|
||
|
||
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
import dataclasses | ||
import re | ||
from typing import Any, Dict, List, Protocol, Tuple, Type | ||
|
||
import torch | ||
|
||
from torchao.core.config import AOBaseConfig | ||
from torchao.quantization.quant_api import ( | ||
Float8DynamicActivationFloat8WeightConfig, | ||
Float8WeightOnlyConfig, | ||
Int4DynamicActivationInt4WeightConfig, | ||
Int4WeightOnlyConfig, | ||
Int8DynamicActivationInt4WeightConfig, | ||
Int8DynamicActivationInt8WeightConfig, | ||
Int8WeightOnlyConfig, | ||
MappingType, | ||
# Float8StaticActivationFloat8WeightConfig, | ||
UIntXWeightOnlyConfig, | ||
) | ||
|
||
# Create a type alias for AOBaseConfig classes | ||
ConfigType = Type[AOBaseConfig] | ||
|
||
|
||
# Define a Protocol for parameter processors | ||
class ParameterProcessor(Protocol):
    """Callable interface implemented by every parameter-token processor."""

    def __call__(self, match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
        """
        Turn a regex match for one token into a keyword argument.

        Args:
            match: The regex match object holding the token's captured groups
            quant_config: The quantization config class that will be instantiated

        Returns:
            Tuple of (parameter_name, parameter_value)

        Note:
            ``quant_config`` is a class, not an instance, so any special
            handling keyed on it must use issubclass rather than isinstance.
        """
        ...
|
||
|
||
def process_bits(match: re.Match, quant_config: "ConfigType") -> Tuple[str, Any]:
    """
    Convert a bit-width token match into the ``bits`` keyword argument.

    Args:
        match: Regex match whose first group holds the decimal bit count
        quant_config: The quantization config class being instantiated (unused)

    Returns:
        Tuple of ("bits", bit count as int)
    """
    return "bits", int(match.group(1))
|
||
|
||
def process_group_size(
    match: re.Match, quant_config: "ConfigType"
) -> Tuple[str, Any]:
    """
    Convert a group-size token match into the ``group_size`` keyword argument.

    Args:
        match: Regex match whose first group holds the decimal group size
        quant_config: The quantization config class being instantiated (unused)

    Returns:
        Tuple of ("group_size", group size as int)
    """
    return "group_size", int(match.group(1))
|
||
|
||
def process_activation_bits(
    match: re.Match, quant_config: "ConfigType"
) -> Tuple[str, Any]:
    """
    Convert an activation bit-width match into ``activation_bits``.

    Args:
        match: Regex match whose first group holds the decimal bit count
        quant_config: The quantization config class being instantiated (unused)

    Returns:
        Tuple of ("activation_bits", bit count as int)
    """
    return "activation_bits", int(match.group(1))
|
||
|
||
def process_weight_bits(
    match: re.Match, quant_config: "ConfigType"
) -> Tuple[str, Any]:
    """
    Convert a weight bit-width match into the ``weight_bits`` keyword argument.

    Args:
        match: Regex match whose first group holds the decimal bit count
        quant_config: The quantization config class being instantiated (unused)

    Returns:
        Tuple of ("weight_bits", bit count as int)
    """
    return "weight_bits", int(match.group(1))
|
||
|
||
def process_symmetry(match: re.Match, quant_config: "ConfigType") -> Tuple[str, Any]:
    """
    Convert a symmetry token match into the ``mapping_type`` keyword argument.

    Args:
        match: Regex match whose first group is the symmetry token
        quant_config: The quantization config class being instantiated (unused)

    Returns:
        Tuple of ("mapping_type", MappingType.SYMMETRIC) when the captured
        token is "sym", otherwise ("mapping_type", MappingType.ASYMMETRIC)
    """
    if match.group(1) == "sym":
        return "mapping_type", MappingType.SYMMETRIC
    return "mapping_type", MappingType.ASYMMETRIC
|
||
|
||
def process_dtype(match: re.Match, quant_config: AOBaseConfig) -> Tuple[str, Any]: | ||
dtype_map = { | ||
"int4": torch.int4, | ||
"int8": torch.int8, | ||
"uint4": torch.uint4, | ||
"uint8": torch.uint8, | ||
"e4m3": torch.float8_e4m3fn, | ||
"e5m2": torch.float8_e5m2, | ||
} | ||
# The float8's have different key names :( | ||
key = ( | ||
"weight_dtype" | ||
if issubclass( | ||
quant_config, | ||
( | ||
Float8WeightOnlyConfig, | ||
Float8DynamicActivationFloat8WeightConfig, | ||
), | ||
) | ||
else "dtype" | ||
) | ||
return key, dtype_map[match.group(1)] | ||
|
||
|
||
def process_per_row(match: re.Match, quant_config: "ConfigType") -> Tuple[str, Any]:
    """
    Convert a per-row token match into the ``per_row`` keyword argument.

    Args:
        match: Regex match for the token (its content is not inspected)
        quant_config: The quantization config class being instantiated (unused)

    Returns:
        Tuple of ("per_row", True)
    """
    return "per_row", True
|
||
|
||
class ConfigParser:
    """
    Parser for string-based configuration patterns.

    A config string has the form ``<base>[_<token>...]``. The first token
    selects a config class through ``type_mapping``; each following token must
    fully match one regex in ``param_patterns``, whose processor converts it
    into one keyword argument for that class.
    """

    # Parameter patterns with their processing functions
    param_patterns: Dict[re.Pattern, ParameterProcessor] = {
        re.compile(r"(\d+)bit"): process_bits,
        re.compile(r"g(\d+)"): process_group_size,
        # re.compile(r"act(\d+)"): process_activation_bits,
        # re.compile(r"w(\d+)"): process_weight_bits,
        re.compile(r"(sym|asym)"): process_symmetry,
        re.compile(r"(int4|int8|uint4|uint8|e4m3|e5m2)"): process_dtype,
        re.compile(r"(per_row)"): process_per_row,
    }

    # Map from string prefix to the config class it selects
    type_mapping = {
        "int4wo": Int4WeightOnlyConfig,
        "int8wo": Int8WeightOnlyConfig,
        "int8dqint4": Int8DynamicActivationInt4WeightConfig,
        "int8dqint8": Int8DynamicActivationInt8WeightConfig,
        "int4dqint4": Int4DynamicActivationInt4WeightConfig,
        "float8wo": Float8WeightOnlyConfig,
        "float8dqfloat8": Float8DynamicActivationFloat8WeightConfig,
        # "float8staticfloat8": Float8StaticActivationFloat8WeightConfig,
        "uintxwo": UIntXWeightOnlyConfig,
        # "fpx": FPXWeightOnlyConfig,
    }

    def parse(self, config_str: str) -> AOBaseConfig:
        """
        Parse a configuration string into an AO quantization configuration object.

        This is the main entrypoint for converting string-based configuration
        into actual config objects. The expected format is
        "base_token1_token2": "base" identifies the base quantization type and
        every subsequent underscore-separated token encodes one parameter.

        Examples:
            config_parser.parse("int8dqint8")
            config_parser.parse("int8dqint4_g32")
            config_parser.parse("int8dqint4_g32_sym")

        Args:
            config_str: String representation of the quantization configuration

        Returns:
            AOBaseConfig: Instantiated quantization configuration object

        Raises:
            ValueError: If the config string is empty, names an unknown base
                type, contains an unrecognized token, or yields parameters
                the config class does not accept
        """
        # str.split never returns an empty list ("".split("_") == [""]), so
        # emptiness has to be checked on the string itself.
        if not config_str or not config_str.strip():
            raise ValueError("Empty config string")

        # The first token is the base quantization type
        base_token, *param_tokens = config_str.split("_")
        quant_config = self._get_config(base_token)

        # We know the base quant type, now we convert each token to its parameter
        params = self._extract_params(quant_config, param_tokens)

        return self._instantiate_config(quant_config, params)

    def _get_config(self, first_token: str) -> ConfigType:
        """
        Look up the config class registered for a base token.

        Raises:
            ValueError: If ``first_token`` is not a key of ``type_mapping``;
                the message lists the available base configurations
        """
        try:
            return self.type_mapping[first_token]
        except KeyError:
            available_configs = list(self.type_mapping)
            # The KeyError adds nothing beyond the message below.
            raise ValueError(
                f"Unknown quantization type in string: {first_token} \n Available base configurations: {available_configs}"
            ) from None

    def _instantiate_config(
        self, quant_config: ConfigType, params: Dict[str, Any]
    ) -> AOBaseConfig:
        """
        Instantiate ``quant_config`` with ``params``, translating TypeError.

        Raises:
            ValueError: If the constructor raises TypeError. The message names
                the keywords that are not dataclass fields of the config
                class, or carries the original error when every keyword is a
                valid field name.
        """
        try:
            return quant_config(**params)
        except TypeError as e:
            field_info = [field.name for field in dataclasses.fields(quant_config)]
            invalid_params = [name for name in params if name not in field_info]

            if not invalid_params:
                # Every keyword is a known field, so the failure has another
                # cause; surface it instead of an empty "invalid" list.
                raise ValueError(
                    f"Could not instantiate {quant_config.__name__} with parameters {list(params)}: {e}\n"
                    f"Available parameters for {quant_config.__name__}: {field_info}"
                ) from e

            raise ValueError(
                f"Invalid parameters for {quant_config.__name__}: {invalid_params}.\n"
                f"Available parameters for {quant_config.__name__}: {field_info}"
            ) from e

    def _extract_params(
        self, quant_config: ConfigType, param_tokens: List[str]
    ) -> Dict[str, Any]:
        """
        Convert parameter tokens into constructor keyword arguments.

        Empty tokens are skipped. When the same keyword is produced more than
        once, the last occurrence wins.

        Raises:
            ValueError: If a token matches none of ``param_patterns``
        """
        params: Dict[str, Any] = {}

        # Tokens come from splitting on "_", so a pattern whose text contains
        # "_" (e.g. "per_row") arrives spread over several tokens. Candidates
        # re-joining up to that many consecutive tokens are therefore tried,
        # longest first.
        max_span = 1 + max(
            (pattern.pattern.count("_") for pattern in self.param_patterns),
            default=0,
        )

        index = 0
        while index < len(param_tokens):
            token = param_tokens[index]
            if not token:
                index += 1
                continue

            consumed = 0
            widest = min(max_span, len(param_tokens) - index)
            for span in range(widest, 0, -1):
                candidate = "_".join(param_tokens[index : index + span])
                # We could specify an ordering but for now we just try all
                for pattern, processor in self.param_patterns.items():
                    match = pattern.fullmatch(candidate)
                    if match:
                        param_name, value = processor(match, quant_config)
                        params[param_name] = value
                        consumed = span
                        break
                if consumed:
                    break

            if not consumed:
                field_info = [
                    (field.name, field.type)
                    for field in dataclasses.fields(quant_config)
                ]
                raise ValueError(
                    f"Unrecognized parameter token: {token} in {param_tokens}\nAvailable parameters for {quant_config.__name__}: {field_info}"
                )

            index += consumed

        return params