Initial stab at string based config parser #1774
Draft
drisspg wants to merge 1 commit into main from drisspg/stack/39
+358 −0
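
The idea, as exercised by the tests below: a compact string like "int8dqint4_g32_sym" expands into a full torchao config object. A minimal usage sketch, assuming the module path used by this PR's tests (torchao.quantization.config_parser):

from torchao.quantization.config_parser import ConfigParser

parser = ConfigParser()

# "int8dqint4" selects the base config class; "g32" and "sym" set parameters
config = parser.parse("int8dqint4_g32_sym")
# -> Int8DynamicActivationInt4WeightConfig(group_size=32,
#        mapping_type=MappingType.SYMMETRIC)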
New test file (132 lines added):

@@ -0,0 +1,132 @@
import unittest

import torch

from torchao.quantization.config_parser import ConfigParser
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    MappingType,
    UIntXWeightOnlyConfig,
)


class TestConfigParser(unittest.TestCase):
    def setUp(self):
        self.parser = ConfigParser()

    def test_int4wo_config(self):
        # Basic Int4WeightOnlyConfig
        config = self.parser.parse("int4wo_g32")
        self.assertIsInstance(config, Int4WeightOnlyConfig)
        self.assertEqual(config.group_size, 32)

        # With a different group size
        config = self.parser.parse("int4wo_g64")
        self.assertIsInstance(config, Int4WeightOnlyConfig)
        self.assertEqual(config.group_size, 64)

    def test_int8wo_config(self):
        # Basic Int8WeightOnlyConfig
        config = self.parser.parse("int8wo_g128")
        self.assertIsInstance(config, Int8WeightOnlyConfig)
        self.assertEqual(config.group_size, 128)

        # Verify that the symmetry parameter raises an error since it is not supported
        with self.assertRaises(ValueError) as context:
            self.parser.parse("int8wo_g128_sym")

        self.assertIn(
            "Invalid parameters for Int8WeightOnlyConfig", str(context.exception)
        )
        self.assertIn("mapping_type", str(context.exception))

    def test_int8dqint4_config(self):
        # Int8 dynamic activation with Int4 weight
        config = self.parser.parse("int8dqint4_g32")
        self.assertIsInstance(config, Int8DynamicActivationInt4WeightConfig)
        self.assertEqual(config.group_size, 32)

        # With symmetry
        config = self.parser.parse("int8dqint4_g32_sym")
        self.assertIsInstance(config, Int8DynamicActivationInt4WeightConfig)
        self.assertEqual(config.group_size, 32)
        self.assertEqual(config.mapping_type, MappingType.SYMMETRIC)

    def test_int8dqint8_config(self):
        # Int8 dynamic activation with Int8 weight
        config = self.parser.parse("int8dqint8")
        self.assertIsInstance(config, Int8DynamicActivationInt8WeightConfig)

    def test_int4dqint4_config(self):
        # Int4 dynamic activation with Int4 weight
        config = self.parser.parse("int4dqint4_sym")
        self.assertIsInstance(config, Int4DynamicActivationInt4WeightConfig)
        self.assertEqual(config.mapping_type, MappingType.SYMMETRIC)

    def test_float8wo_config(self):
        # Basic Float8WeightOnlyConfig with e4m3 dtype
        config = self.parser.parse("float8wo_e4m3")
        self.assertIsInstance(config, Float8WeightOnlyConfig)
        self.assertEqual(config.weight_dtype, torch.float8_e4m3fn)

        # With e5m2 dtype
        config = self.parser.parse("float8wo_e5m2")
        self.assertIsInstance(config, Float8WeightOnlyConfig)
        self.assertEqual(config.weight_dtype, torch.float8_e5m2)

    def test_float8dqfloat8_config(self):
        # Float8 dynamic activation with Float8 weight
        config = self.parser.parse("float8dqfloat8_e4m3")
        self.assertIsInstance(config, Float8DynamicActivationFloat8WeightConfig)
        self.assertEqual(config.activation_dtype, torch.float8_e4m3fn)
        self.assertEqual(config.weight_dtype, torch.float8_e4m3fn)

    def test_uintxwo_config(self):
        # UIntX config with uint4
        config = self.parser.parse("uintxwo_uint4_g32")
        self.assertIsInstance(config, UIntXWeightOnlyConfig)

        # With uint8
        config = self.parser.parse("uintxwo_uint8_g64")
        self.assertIsInstance(config, UIntXWeightOnlyConfig)

    # def test_fpx_config(self):
    #     # FPX config
    #     config = self.parser.parse("fpx_e4m3")
    #     self.assertIsInstance(config, FPXWeightOnlyConfig)
    #     self.assertEqual(config.ebits, 4)
    #     self.assertEqual(config.mbits, 3)

    def test_invalid_config_string(self):
        # Test empty string
        with self.assertRaises(ValueError):
            self.parser.parse("")

        # Test unknown base config
        with self.assertRaises(ValueError):
            self.parser.parse("unknown_config")

        # Test invalid parameter token
        with self.assertRaises(ValueError):
            self.parser.parse("int4wo_invalid_token")

    def test_complex_configurations(self):
        # Multi-parameter configurations exercising the actual parameter names
        config = self.parser.parse("int4wo_g32")
        self.assertIsInstance(config, Int4WeightOnlyConfig)
        self.assertEqual(config.group_size, 32)

        config = self.parser.parse("int8dqint4_g32_asym")
        self.assertIsInstance(config, Int8DynamicActivationInt4WeightConfig)
        self.assertEqual(config.group_size, 32)
        self.assertEqual(config.mapping_type, MappingType.ASYMMETRIC)


if __name__ == "__main__":
    unittest.main()
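
For context, one way a parsed config could be applied to a model. This is a hedged sketch, not part of the PR, and assumes torchao's existing quantize_ entrypoint:

import torch
from torchao.quantization import quantize_
from torchao.quantization.config_parser import ConfigParser

model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# A string flag (e.g. from a CLI) becomes a full config object...
config = ConfigParser().parse("int8wo_g128")

# ...which plugs into the existing quantize_ API unchanged.
quantize_(model, config)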
New file: torchao/quantization/config_parser.py (226 lines added):

@@ -0,0 +1,226 @@
import dataclasses
import re
from typing import Any, Dict, List, Protocol, Tuple, Type

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    MappingType,
    # Float8StaticActivationFloat8WeightConfig,
    UIntXWeightOnlyConfig,
)

# Type alias for AOBaseConfig classes. The parser works with the classes
# themselves, not instances, until final instantiation.
ConfigType = Type[AOBaseConfig]


class ParameterProcessor(Protocol):
    """Protocol defining the interface for parameter processors"""

    def __call__(self, match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
        """
        Process a regex match into a parameter name and value.

        Args:
            match: The regex match object containing captured groups
            quant_config: The quantization config class being instantiated

        Returns:
            Tuple of (parameter_name, parameter_value)

        Note:
            quant_config is a class, not an instance, so any special handling
            based on its type must use issubclass instead of isinstance.
        """
        ...


def process_bits(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    return "bits", int(match.group(1))


def process_group_size(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    return "group_size", int(match.group(1))


def process_activation_bits(
    match: re.Match, quant_config: ConfigType
) -> Tuple[str, Any]:
    return "activation_bits", int(match.group(1))


def process_weight_bits(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    return "weight_bits", int(match.group(1))


def process_symmetry(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    mapping_type = (
        MappingType.SYMMETRIC if match.group(1) == "sym" else MappingType.ASYMMETRIC
    )
    return "mapping_type", mapping_type


def process_dtype(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    dtype_map = {
        "int4": torch.int4,
        "int8": torch.int8,
        "uint4": torch.uint4,
        "uint8": torch.uint8,
        "e4m3": torch.float8_e4m3fn,
        "e5m2": torch.float8_e5m2,
    }
    # The float8 configs use a different key name :(
    key = (
        "weight_dtype"
        if issubclass(
            quant_config,
            (
                Float8WeightOnlyConfig,
                Float8DynamicActivationFloat8WeightConfig,
            ),
        )
        else "dtype"
    )
    return key, dtype_map[match.group(1)]


def process_per_row(match: re.Match, quant_config: ConfigType) -> Tuple[str, Any]:
    return "per_row", True


class ConfigParser:
    """Parser for string-based configuration patterns"""

    # Parameter patterns with their processing functions
    param_patterns: Dict[re.Pattern, ParameterProcessor] = {
        re.compile(r"(\d+)bit"): process_bits,
        re.compile(r"g(\d+)"): process_group_size,
        # re.compile(r"act(\d+)"): process_activation_bits,
        # re.compile(r"w(\d+)"): process_weight_bits,
        re.compile(r"(sym|asym)"): process_symmetry,
        re.compile(r"(int4|int8|uint4|uint8|e4m3|e5m2)"): process_dtype,
        re.compile(r"(per_row)"): process_per_row,
    }

    # Map from string prefix to config class
    type_mapping = {
        "int4wo": Int4WeightOnlyConfig,
        "int8wo": Int8WeightOnlyConfig,
        "int8dqint4": Int8DynamicActivationInt4WeightConfig,
        "int8dqint8": Int8DynamicActivationInt8WeightConfig,
        "int4dqint4": Int4DynamicActivationInt4WeightConfig,
        "float8wo": Float8WeightOnlyConfig,
        "float8dqfloat8": Float8DynamicActivationFloat8WeightConfig,
        # "float8staticfloat8": Float8StaticActivationFloat8WeightConfig,
        "uintxwo": UIntXWeightOnlyConfig,
        # "fpx": FPXWeightOnlyConfig,
    }

    def parse(self, config_str: str) -> AOBaseConfig:
        """
        Parse a configuration string into an AO quantization configuration object.

        This is the main entrypoint for converting a string-based configuration
        into an actual config object. The expected format is
        "base_token1_token2_..." where "base" identifies the base quantization
        type and each subsequent underscore-separated token sets one parameter,
        e.g. "g32" for group_size=32 or "sym" for symmetric mapping.

        Examples:
            config_parser.parse("int8dqint8")
            config_parser.parse("int8dqint4_g32")

        Args:
            config_str: String representation of the quantization configuration

        Returns:
            AOBaseConfig: Instantiated quantization configuration object

        Raises:
            ValueError: If the config string is empty or invalid
        """
        if not config_str:
            raise ValueError("Empty config string")

        tokens = config_str.split("_")

        # The first token is the base quantization type
        quant_config = self._get_config(tokens[0])

        # We know the base quant type; now convert each remaining token to a parameter
        params = self._extract_params(quant_config, tokens[1:])

        return self._instantiate_config(quant_config, params)

    def _get_config(self, first_token: str) -> ConfigType:
        """Get the quantization config class from a string"""
        try:
            quant_config = self.type_mapping[first_token]
        except KeyError:
            # Include the available base configurations in the error message
            available_configs = list(self.type_mapping.keys())
            raise ValueError(
                f"Unknown quantization type in string: {first_token}\n"
                f"Available base configurations: {available_configs}"
            )
        return quant_config

    def _instantiate_config(
        self, quant_config: ConfigType, params: Dict[str, Any]
    ) -> AOBaseConfig:
        """Add some extra logic to help with instantiation failures"""
        try:
            return quant_config(**params)
        except TypeError as e:
            # Get proper field information for the error message
            valid_fields = {
                field.name
                for field in dataclasses.fields(quant_config)
                if field.name != "self"
            }
            invalid_params = {k: v for k, v in params.items() if k not in valid_fields}

            field_info = [field.name for field in dataclasses.fields(quant_config)]

            raise ValueError(
                f"Invalid parameters for {quant_config.__name__}: {list(invalid_params.keys())}.\n"
                f"Available parameters for {quant_config.__name__}: {field_info}"
            ) from e

    def _extract_params(
        self, quant_config: ConfigType, param_tokens: List[str]
    ) -> Dict[str, Any]:
        """Extract parameters from tokens"""
        params = {}

        for token in param_tokens:
            if not token:
                continue

            matched = False
            # Try to match against parameter patterns.
            # We could specify an ordering, but for now we just try all of them.
            for pattern, processor in self.param_patterns.items():
                match = pattern.fullmatch(token)
                if match:
                    param_name, value = processor(match, quant_config)
                    params[param_name] = value
                    matched = True
                    break

            if not matched:
                field_info = [
                    (field.name, field.type)
                    for field in dataclasses.fields(quant_config)
                ]
                raise ValueError(
                    f"Unrecognized parameter token: {token} in {param_tokens}\n"
                    f"Available parameters for {quant_config.__name__}: {field_info}"
                )

        return params
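
To make the token flow concrete, a short illustrative trace of parse("int8dqint4_g32_sym"), based on the implementation above (the comments are annotations, not program output):

# config_str.split("_") -> ["int8dqint4", "g32", "sym"]
#
# _get_config("int8dqint4")
#   -> type_mapping lookup -> Int8DynamicActivationInt4WeightConfig
# _extract_params(Int8DynamicActivationInt4WeightConfig, ["g32", "sym"])
#   "g32" fullmatches r"g(\d+)"      -> ("group_size", 32)
#   "sym" fullmatches r"(sym|asym)"  -> ("mapping_type", MappingType.SYMMETRIC)
# _instantiate_config(...)
#   -> Int8DynamicActivationInt4WeightConfig(group_size=32,
#          mapping_type=MappingType.SYMMETRIC)
from torchao.quantization.config_parser import ConfigParser

config = ConfigParser().parse("int8dqint4_g32_sym")
assert config.group_size == 32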
Review comment: Thanks! One more thing: can you add layout here as well?
ao/torchao/quantization/quant_api.py, line 814 in f478692