
Commit

Update pydantic support to v2 (#2248)
* Update setup.py

* Bump Pydantic

* Fix typing and model_config dicts

* Fix all tests except recipe serialization

* Fix recipe serialization
Remove stale/misleading doctest
rahul-tuli authored Apr 23, 2024
1 parent 1abc30c commit 82e204a
Showing 18 changed files with 125 additions and 84 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -45,7 +45,7 @@
"pandas>=0.25.0",
"packaging>=20.0",
"psutil>=5.0.0",
"pydantic>=1.8.2,<2.0.0",
"pydantic>=2.0.0,<2.8.0",
"requests>=2.0.0",
"scikit-learn>=0.24.2",
"scipy<1.9.2,>=1.8; python_version <= '3.9'",
6 changes: 3 additions & 3 deletions src/sparseml/core/modifier/modifier.py
@@ -40,9 +40,9 @@ class Modifier(BaseModel, ModifierInterface, MultiFrameworkObject):
:param update: The update step for the modifier
"""

- index: int = None
- group: str = None
- start: float = None
+ index: Optional[int] = None
+ group: Optional[str] = None
+ start: Optional[float] = None
end: Optional[float] = None
update: Optional[float] = None

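A note on the Optional annotations above: pydantic v1 silently rewrote a field declared as index: int = None into Optional[int], while v2 removes that special case, so a None default must be paired with an Optional annotation or explicit None input fails validation. A minimal sketch of the v2 behavior, using stand-in models rather than SparseML's actual Modifier:

    from typing import Optional

    from pydantic import BaseModel, ValidationError

    class ModifierLike(BaseModel):
        index: Optional[int] = None  # v2 style: Optional is spelled out
        group: Optional[str] = None

    class LegacyStyle(BaseModel):
        index: int = None  # v1 made this Optional[int]; v2 keeps it strictly int

    print(ModifierLike(index=3, group="pruning"))  # index=3 group='pruning'
    try:
        LegacyStyle(index=None)  # explicit None now fails against the int annotation
    except ValidationError as err:
        print(err)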
8 changes: 5 additions & 3 deletions src/sparseml/core/recipe/base.py
@@ -13,9 +13,9 @@
# limitations under the License.

from abc import ABC, abstractmethod
- from typing import Any
+ from typing import Any, Optional

- from pydantic import BaseModel
+ from pydantic import BaseModel, ConfigDict

from sparseml.core.framework import Framework
from sparseml.core.recipe.args import RecipeArgs
@@ -36,6 +36,8 @@ class RecipeBase(BaseModel, ABC):
- create_modifier
"""

+ model_config = ConfigDict(arbitrary_types_allowed=True)

@abstractmethod
def calculate_start(self) -> int:
raise NotImplementedError()
@@ -45,7 +47,7 @@ def calculate_end(self) -> int:
raise NotImplementedError()

@abstractmethod
- def evaluate(self, args: RecipeArgs = None, shift: int = None):
+ def evaluate(self, args: Optional[RecipeArgs] = None, shift: Optional[int] = None):
raise NotImplementedError()

@abstractmethod
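The ConfigDict line added above is the v2 replacement for v1's inner class Config. A sketch of the equivalence with a stand-in arbitrary type (illustrative names, not SparseML's classes):

    from pydantic import BaseModel, ConfigDict

    class Framework:  # stand-in for any non-pydantic type carried on the model
        pass

    class RecipeLike(BaseModel):
        # v1 spelling of the same setting:
        #     class Config:
        #         arbitrary_types_allowed = True
        model_config = ConfigDict(arbitrary_types_allowed=True)

        framework: Framework  # only accepted because arbitrary types are allowed

    RecipeLike(framework=Framework())  # ok; a plain string here would fail isinstance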
5 changes: 3 additions & 2 deletions src/sparseml/core/recipe/modifier.py
@@ -14,7 +14,7 @@

from typing import Any, Dict, Optional

- from pydantic import root_validator
+ from pydantic import model_validator

from sparseml.core.factory import ModifierFactory
from sparseml.core.framework import Framework
@@ -99,7 +99,8 @@ def create_modifier(self, framework: Framework) -> "Modifier":
**self.args_evaluated,
)

- @root_validator(pre=True)
+ @model_validator(mode="before")
+ @classmethod
def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
modifier = {"group": values.pop("group")}
assert len(values) == 1, "multiple key pairs found for modifier"
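The decorator swap above follows the standard v1-to-v2 migration: @root_validator(pre=True) becomes @model_validator(mode="before"), stacked with @classmethod because the hook receives the class. A hedged sketch of the same extraction pattern on a stand-in model (the field names and the tail of the validator body are assumptions, not copied from the repo):

    from typing import Any, Dict

    from pydantic import BaseModel, model_validator

    class RecipeModifierLike(BaseModel):
        group: str
        type: str
        args: Dict[str, Any]

        @model_validator(mode="before")
        @classmethod
        def extract_modifier_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
            # pull out the group, then treat the single remaining key as the type
            modifier = {"group": values.pop("group")}
            assert len(values) == 1, "multiple key pairs found for modifier"
            modifier["type"], modifier["args"] = next(iter(values.items()))
            return modifier

    m = RecipeModifierLike.model_validate(
        {"group": "pruning", "ConstantPruningModifier": {"start": 0.0, "end": 2.0}}
    )
    print(m.type, m.args)  # ConstantPruningModifier {'start': 0.0, 'end': 2.0}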
122 changes: 77 additions & 45 deletions src/sparseml/core/recipe/recipe.py
@@ -20,7 +20,7 @@
from typing import Any, Dict, List, Optional, Union

import yaml
- from pydantic import Field, root_validator
+ from pydantic import Field, model_validator

from sparseml.core.framework import Framework
from sparseml.core.modifier import StageModifiers
@@ -152,7 +152,7 @@ def create_instance(
)
_LOGGER.debug(f"Input string: {path_or_modifiers}")
obj = _load_json_or_yaml_string(path_or_modifiers)
- return Recipe.parse_obj(obj)
+ return Recipe.model_validate(obj)
else:
_LOGGER.info(f"Loading recipe from file {path_or_modifiers}")

@@ -174,7 +174,7 @@ def create_instance(
raise ValueError(
f"Could not parse recipe from path {path_or_modifiers}"
)
- return Recipe.parse_obj(obj)
+ return Recipe.model_validate(obj)

@staticmethod
def simplify_recipe(
@@ -391,7 +391,8 @@ def create_modifier(self, framework: Framework) -> List["StageModifiers"]:

return modifiers

- @root_validator(pre=True)
+ @model_validator(mode="before")
+ @classmethod
def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]:
stages = []

@@ -515,25 +516,9 @@ def combine_metadata(self, metadata: Optional[RecipeMetaData]):

def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
- >>> recipe_str = '''
- ... test_stage:
- ...     pruning_modifiers:
- ...         ConstantPruningModifier:
- ...             start: 0.0
- ...             end: 2.0
- ...             targets: ['re:.*weight']
- ... '''
- >>> recipe = Recipe.create_instance(recipe_str)
- >>> recipe_dict = recipe.dict()
- >>> stage = recipe_dict["stages"]["test"]
- >>> pruning_mods = stage[0]['modifiers']['pruning']
- >>> modifier_args = pruning_mods[0]['ConstantPruningModifier']
- >>> modifier_args == {'start': 0.0, 'end': 2.0, 'targets': ['re:.*weight']}
- True
:return: A dictionary representation of the recipe
"""
- dict_ = super().dict(*args, **kwargs)
+ dict_ = super().model_dump(*args, **kwargs)
stages = {}

for stage in dict_["stages"]:
@@ -577,36 +562,34 @@ def _get_yaml_dict(self) -> Dict[str, Any]:
"""
Get a dictionary representation of the recipe for yaml serialization
The returned dict will only contain information necessary for yaml
- serialization (ignores metadata, version, etc), and must not be used
- in place of the dict method
+ serialization and must not be used in place of the dict method
:return: A dictionary representation of the recipe for yaml serialization
"""

- def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]):
-     # convert a list of modifiers to a dict of modifiers
-     return {
-         key: value
-         for modifier in modifier_group
-         for key, value in modifier.items()
-     }
-
- def _stage_to_dict(stage: Dict[str, Any]):
-     # convert a stage to a dict of modifiers
-     return {
-         modifier_group_name: _modifier_group_to_dict(modifier_group)
-         for modifier_group_name, modifier_group in stage["modifiers"].items()
-     }
-
- final_dict = {}
- for stage_name, stages in self.dict()["stages"].items():
-     if len(stages) == 1:
-         final_dict[stage_name] = _stage_to_dict(stages[0])
-     else:
-         for idx, stage in enumerate(stages):
-             final_dict[stage_name + "_" + str(idx)] = _stage_to_dict(stage)
-
- return final_dict
+ original_recipe_dict = self.dict()
+ yaml_recipe_dict = {}
+
+ # populate recipe level attributes
+ recipe_level_attributes = ["version", "args", "metadata"]
+
+ for attribute in recipe_level_attributes:
+     if attribute_value := original_recipe_dict.get(attribute):
+         yaml_recipe_dict[attribute] = attribute_value
+
+ # populate stages
+ stages = original_recipe_dict["stages"]
+ for stage_name, stage_list in stages.items():
+     # stage is always a list of size 1
+     stage = stage_list[0]
+     stage_dict = get_yaml_serializable_stage_dict(modifiers=stage["modifiers"])
+
+     # infer run_type from stage
+     if run_type := stage.get("run_type"):
+         stage_dict["run_type"] = run_type
+
+     yaml_recipe_dict[stage_name] = stage_dict
+ return yaml_recipe_dict


@dataclass
@@ -704,9 +687,58 @@ def create_recipe_string_from_modifiers(
recipe_dict = {
f"{modifier_group_name}_stage": {
f"{default_group_name}_modifiers": {
- modifier.__class__.__name__: modifier.dict() for modifier in modifiers
+ modifier.__class__.__name__: modifier.model_dump()
+ for modifier in modifiers
}
}
}
recipe_str: str = yaml.dump(recipe_dict)
return recipe_str


+ def get_modifiers_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
+     group_dict = {}
+
+     for modifier in modifiers:
+         modifier_type = modifier["type"]
+         modifier_group = modifier["group"]
+
+         if modifier_group not in group_dict:
+             group_dict[modifier_group] = []
+
+         modifier_dict = {modifier_type: modifier["args"]}
+         group_dict[modifier_group].append(modifier_dict)
+
+     return group_dict
+
+
+ def get_yaml_serializable_stage_dict(modifiers: List[Dict[str, Any]]) -> Dict[str, Any]:
+     """
+     This function is used to convert a list of modifiers into a dictionary
+     where the keys are the group names and the values are the modifiers,
+     which in turn are dictionaries with the modifier type as the key and
+     the modifier args as the value.
+     This is needed to conform to our recipe structure during yaml serialization,
+     where each stage, modifier group, and modifier is represented as a
+     valid yaml dictionary.
+     Note: This function assumes that modifier groups do not contain the same
+     modifier type more than once in a group. This assumption is also held by
+     the Recipe.create_instance(...) method.
+     :param modifiers: A list of dictionaries where each dictionary
+         holds all information about a modifier
+     :return: A dictionary where the keys are the group names and the values
+         are the modifiers, which in turn are dictionaries with the modifier
+         type as the key and the modifier args as the value
+     """
+     stage_dict = {}
+     for modifier in modifiers:
+         group_name = f"{modifier['group']}_modifiers"
+         modifier_type = modifier["type"]
+         if group_name not in stage_dict:
+             stage_dict[group_name] = {}
+         stage_dict[group_name][modifier_type] = modifier["args"]
+     return stage_dict
7 changes: 5 additions & 2 deletions src/sparseml/core/recipe/stage.py
@@ -15,7 +15,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional

- from pydantic import Field, root_validator
+ from pydantic import ConfigDict, Field, model_validator

from sparseml.core.framework import Framework
from sparseml.core.modifier import StageModifiers
@@ -46,6 +46,8 @@ class RecipeStage(RecipeBase):
:param args_evaluated: the evaluated RecipeArgs for the stage
"""

+ model_config = ConfigDict(arbitrary_types_allowed=True)

group: Optional[str] = None
run_type: Optional[StageRunType] = None
args: Optional[RecipeArgs] = None
@@ -139,7 +141,8 @@ def create_modifier(

return stage_modifiers

- @root_validator(pre=True)
+ @model_validator(mode="before")
+ @classmethod
def remap_modifiers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
modifiers = RecipeStage.extract_dict_modifiers(values)
values["modifiers"] = modifiers
13 changes: 6 additions & 7 deletions src/sparseml/exporters/transforms/kv_cache/configs.py
@@ -17,7 +17,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, ConfigDict, Field

from sparseml.exporters.transforms import OnnxTransform
from sparseml.exporters.transforms.kv_cache.transforms_codegen import (
@@ -47,8 +47,9 @@ class KeyValueCacheConfig(BaseModel):
additional_transforms: Union[
List[Type[OnnxTransform]], Type[OnnxTransform], None
] = Field(
+ None,
description="A transform class (or list thereof) to use for additional "
"transforms to the model required for finalizing the kv cache injection."
"transforms to the model required for finalizing the kv cache injection.",
)
key_num_attention_heads: str = Field(
description="The key to use to get the number of attention heads from the "
Expand All @@ -59,10 +60,10 @@ class KeyValueCacheConfig(BaseModel):
"from the transformer's `config.json` file."
)
num_attention_heads: Optional[int] = Field(
description="The number of attention heads."
None, description="The number of attention heads."
)
hidden_size_kv_cache: Optional[int] = Field(
description="The hidden size of the key/value cache. "
None, description="The hidden size of the key/value cache. "
)
multiply_batch_by_num_att_heads: bool = Field(
default=False,
Expand All @@ -83,9 +84,7 @@ class KeyValueCacheConfig(BaseModel):
"the kv cache. If this is not provided, no transpose will "
"be applied.",
)

- class Config:
-     arbitrary_types_allowed = True
+ model_config = ConfigDict(arbitrary_types_allowed=True)


OPT_CONFIG = KeyValueCacheConfig(
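The None defaults threaded through the Field(...) calls above address another v2 change: an Optional[...] annotation no longer implies a None default, so a Field carrying only a description leaves the field required. A minimal sketch (stand-in model, not the real KeyValueCacheConfig):

    from typing import Optional

    from pydantic import BaseModel, Field, ValidationError

    class CacheConfigLike(BaseModel):
        # v1: Optional implied a None default; v2: the default must be explicit
        num_attention_heads: Optional[int] = Field(
            None, description="The number of attention heads."
        )
        key_num_attention_heads: str = Field(description="No default: required in v2.")

    CacheConfigLike(key_num_attention_heads="num_heads")  # ok, heads default to None
    try:
        CacheConfigLike()  # fails: the required key is missing
    except ValidationError as err:
        print(err)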
4 changes: 2 additions & 2 deletions src/sparseml/framework/info.py
@@ -231,13 +231,13 @@ def save_framework_info(framework: Any, path: Optional[str] = None):
create_parent_dirs(path)

with open(path, "w") as file:
- file.write(info.json())
+ file.write(info.model_dump_json())

_LOGGER.info(
"saved framework info for framework %s in file at %s", framework, path
),
else:
- print(info.json(indent=4))
+ print(info.model_dump_json(indent=4))
_LOGGER.info("printed out framework info for framework %s", framework)


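The .json() to .model_dump_json() change above is one of a family of v2 renames this commit applies throughout: dict() becomes model_dump(), json() becomes model_dump_json(), and parse_obj() becomes model_validate(). The v1 names still work in pydantic 2.x but emit deprecation warnings. A quick sketch with a stand-in model:

    from pydantic import BaseModel

    class InfoLike(BaseModel):
        framework: str
        version: str

    info = InfoLike(framework="pytorch", version="2.1.0")
    print(info.model_dump())               # replaces info.dict()
    print(info.model_dump_json(indent=4))  # replaces info.json(indent=4)
    print(InfoLike.model_validate(         # replaces InfoLike.parse_obj(...)
        {"framework": "pytorch", "version": "2.1.0"}
    ))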
2 changes: 2 additions & 0 deletions src/sparseml/integration_helper_functions.py
@@ -188,6 +188,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
default=export_model,
)
apply_optimizations: Optional[Callable[[Any], None]] = Field(
+ None,
description="A function that takes:"
" - path to the exported model"
" - names of the optimizations to apply"
@@ -223,6 +224,7 @@ class IntegrationHelperFunctions(RegistryMixin, BaseModel):
)

deployment_directory_files_optional: Optional[List[str]] = Field(
+ None,
description="A list that describes the "
"optional expected files of the deployment directory",
)
@@ -21,7 +21,7 @@

import torch
from packaging import version
- from pydantic import BaseModel, Field, validator
+ from pydantic import BaseModel, Field, field_validator
from torch.nn import Identity


@@ -121,7 +121,8 @@ def get_observer(self) -> "torch.quantization.FakeQuantize":
qconfig_kwargs=self.kwargs,
)

@validator("strategy")
@field_validator("strategy")
@classmethod
def validate_strategy(cls, value):
valid_scopes = ["tensor", "channel"]
if value not in valid_scopes:
@@ -263,7 +264,7 @@ def __str__(self) -> str:
"""
:return: YAML friendly string serialization
"""
- dict_repr = self.dict()
+ dict_repr = self.model_dump()
dict_repr = {
key: val if val is not None else "null" for key, val in dict_repr.items()
}
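The last pattern in the visible diff: v1's @validator("strategy") becomes @field_validator("strategy") in v2, again stacked with @classmethod. A sketch mirroring the strategy check above (stand-in model; the tensor/channel scopes are taken from the snippet):

    from pydantic import BaseModel, ValidationError, field_validator

    class ObserverArgsLike(BaseModel):
        strategy: str = "tensor"

        @field_validator("strategy")
        @classmethod
        def validate_strategy(cls, value: str) -> str:
            valid_scopes = ["tensor", "channel"]
            if value not in valid_scopes:
                raise ValueError(f"strategy must be one of {valid_scopes}, got {value}")
            return value

    print(ObserverArgsLike(strategy="channel"))  # ok
    try:
        ObserverArgsLike(strategy="layer")  # rejected by the field validator
    except ValidationError as err:
        print(err)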