initial commit to bring pydantic to v2
roussel-ryan committed Oct 23, 2023
1 parent ca5e467 commit 03b77dd
Showing 15 changed files with 170 additions and 139 deletions.
2 changes: 1 addition & 1 deletion environment.yml
@@ -4,6 +4,6 @@ channels:
- conda-forge
dependencies:
- python=3.9
- pydantic==1.10.9
- pydantic>2.3
- numpy
- pyyaml
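The dependency pin is the heart of the migration: instead of the exact `pydantic==1.10.9` release, the environment now accepts any pydantic above 2.3, so all downstream code must use the v2 API. A minimal import-time guard that roughly enforces the new pin could look like the sketch below; it is hypothetical and not part of this commit.

```python
# Hypothetical import-time guard, not part of this commit; it approximates
# the new "pydantic>2.3" pin before any v2-only API is touched.
import pydantic

_major, _minor = (int(p) for p in pydantic.VERSION.split(".")[:2])
if (_major, _minor) < (2, 3):
    raise ImportError(
        f"lume-model now requires pydantic>2.3, found {pydantic.VERSION}"
    )
```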
131 changes: 98 additions & 33 deletions lume_model/base.py
@@ -3,15 +3,15 @@
import yaml
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Union
from typing import Any, Callable, Union, TextIO
from types import FunctionType, MethodType

import numpy as np
from pydantic import BaseModel, validator
from pydantic import BaseModel, ConfigDict, field_validator, SerializeAsAny

from lume_model.variables import (
InputVariable,
OutputVariable,
OutputVariable, ScalarInputVariable, ScalarOutputVariable,
)
from lume_model.utils import (
try_import_module,
@@ -23,7 +23,6 @@

logger = logging.getLogger(__name__)


JSON_ENCODERS = {
# function/method type distinguished for class members and not recognized as callables
FunctionType: lambda x: f"{x.__module__}.{x.__qualname__}",
@@ -96,7 +95,7 @@ def process_keras_model(


def recursive_serialize(
v,
v: dict[str, Any],
base_key: str = "",
file_prefix: Union[str, os.PathLike] = "",
save_models: bool = True,
@@ -121,11 +120,13 @@
if isinstance(value, dict):
v[key] = recursive_serialize(value, key)
elif torch is not None and isinstance(value, torch.nn.Module):
v[key] = process_torch_module(value, base_key, key, file_prefix, save_models)
v[key] = process_torch_module(value, base_key, key, file_prefix,
save_models)
elif isinstance(value, list) and torch is not None and any(
isinstance(ele, torch.nn.Module) for ele in value):
v[key] = [
process_torch_module(value[i], base_key, f"{key}_{i}", file_prefix, save_models)
process_torch_module(value[i], base_key, f"{key}_{i}", file_prefix,
save_models)
for i in range(len(value))
]
elif keras is not None and isinstance(value, keras.Model):
@@ -164,7 +165,6 @@ def recursive_deserialize(v):
def json_dumps(
v,
*,
default,
base_key="",
file_prefix: Union[str, os.PathLike] = "",
save_models: bool = True,
@@ -181,8 +181,8 @@ def json_dumps(
Returns:
JSON formatted string.
"""
v = recursive_serialize(v, base_key, file_prefix, save_models)
v = json.dumps(v, default=default)
v = recursive_serialize(v.model_dump(), base_key, file_prefix, save_models)
v = json.dumps(v)
return v
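pydantic v2 removed the `Config.json_dumps` hook that v1 used to route all JSON serialization through a custom function, which is why `json_dumps` now receives the model itself, dumps it to a plain dict with `model_dump()`, and only then applies `recursive_serialize`. A standalone sketch of the same pattern follows; `_serialize`, `Demo`, and `demo_json_dumps` are simplified hypothetical stand-ins, not names from the commit.

```python
# Standalone sketch of the v2 serialization pattern used above; _serialize
# is a simplified stand-in for recursive_serialize (hypothetical names).
import json
from types import FunctionType
from pydantic import BaseModel


def _serialize(d: dict) -> dict:
    # Encode callables by qualified name, mirroring JSON_ENCODERS above.
    return {
        k: f"{v.__module__}.{v.__qualname__}" if isinstance(v, FunctionType) else v
        for k, v in d.items()
    }


class Demo(BaseModel):
    name: str


def demo_json_dumps(model: BaseModel) -> str:
    # v2 has no Config.json_dumps hook: dump to a plain dict first,
    # then apply the custom serializer and json.dumps explicitly.
    return json.dumps(_serialize(model.model_dump()))


print(demo_json_dumps(Demo(name="example")))  # {"name": "example"}
```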


@@ -232,7 +232,8 @@ def model_kwargs_from_dict(config: dict) -> dict:
"""
config = deserialize_variables(config)
if all(key in config.keys() for key in ["input_variables", "output_variables"]):
config["input_variables"], config["output_variables"] = variables_from_dict(config)
config["input_variables"], config["output_variables"] = variables_from_dict(
config)
_ = config.pop("model_class", None)
return config
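`model_kwargs_from_dict` converts a deserialized config into constructor keyword arguments: it parses the variable specifications and drops the `model_class` tag. Assuming a concrete subclass `MyModel` and a YAML file in the format written by `dump()` (both hypothetical here), usage would look roughly like:

```python
# Hypothetical usage; MyModel stands in for any concrete LUMEBaseModel
# subclass, and my_model.yml for a config written by dump().
import yaml
from lume_model.base import model_kwargs_from_dict

with open("my_model.yml") as f:
    config = yaml.safe_load(f)

kwargs = model_kwargs_from_dict(config)  # parses variables, pops model_class
model = MyModel(**kwargs)
```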

@@ -247,15 +248,44 @@ class LUMEBaseModel(BaseModel, ABC):
input_variables: List defining the input variables and their order.
output_variables: List defining the output variables and their order.
"""
input_variables: list[InputVariable]
output_variables: list[OutputVariable]
input_variables: list[SerializeAsAny[InputVariable]]
output_variables: list[SerializeAsAny[OutputVariable]]

model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

class Config:
extra = "allow"
json_dumps = json_dumps
json_loads = json_loads
validate_assignment = True
arbitrary_types_allowed = True
@field_validator("input_variables", mode="before")
def validate_input_variables(cls, value):
new_value = []
if isinstance(value, dict):
for name, val in value.items():
if isinstance(val, dict):
if val["variable_type"] == "scalar":
new_value.append(ScalarInputVariable(name=name, **val))
elif isinstance(val, InputVariable):
new_value.append(val)
else:
raise TypeError(f"type {type(val)} not supported")
elif isinstance(value, list):
new_value = value

return new_value

@field_validator("output_variables", mode="before")
def validate_output_variables(cls, value):
new_value = []
if isinstance(value, dict):
for name, val in value.items():
if isinstance(val, dict):
if val["variable_type"] == "scalar":
new_value.append(ScalarOutputVariable(name=name, **val))
elif isinstance(val, OutputVariable):
new_value.append(val)
else:
raise TypeError(f"type {type(val)} not supported")
elif isinstance(value, list):
new_value = value

return new_value

def __init__(
self,
@@ -274,7 +304,7 @@ def __init__(
else:
super().__init__(**kwargs)

@validator("input_variables", "output_variables")
@field_validator("input_variables", "output_variables")
def unique_variable_names(cls, value):
verify_unique_variable_names(value)
return value
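The validators above capture the core decorator migration: v1's `@validator` becomes `@field_validator`, and `pre=True` becomes `mode="before"`, which runs the function before type coercion; multi-field validators such as `unique_variable_names` keep the same form. A self-contained sketch of the pattern (illustrative example, not taken from the commit):

```python
# Illustrative sketch of the v1 -> v2 validator migration, mirroring the
# changes above; this Example class is hypothetical, not part of lume-model.
from pydantic import BaseModel, field_validator


class Example(BaseModel):
    names: list[str]

    # pydantic v1 form (removed in this commit):
    #     @validator("names", pre=True)
    # pydantic v2 form; mode="before" runs prior to type coercion:
    @field_validator("names", mode="before")
    def coerce_names(cls, v):
        # Accept a comma-separated string in place of a list.
        if isinstance(v, str):
            return [s.strip() for s in v.split(",")]
        return v


print(Example(names="a, b, c").names)  # ['a', 'b', 'c']
```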
@@ -291,12 +321,35 @@ def output_names(self) -> list[str]:
def evaluate(self, input_dict: dict[str, Any]) -> dict[str, Any]:
pass

def yaml(
def to_json(self, **kwargs) -> str:
return json_dumps(self, **kwargs)

def dict(self, **kwargs) -> dict[str, Any]:
config = super().model_dump(**kwargs)
return {"model_class": self.__class__.__name__} | config

def json(self, **kwargs) -> str:
result = self.to_json(**kwargs)
config = json.loads(result)
config = {"model_class": self.__class__.__name__} | config

return json.dumps(config)

def yaml(self, **kwargs):
"""serialize first then dump to yaml string"""
output = json.loads(
self.to_json(
**kwargs,
)
)
return yaml.dump(output, default_flow_style=None, sort_keys=False)

def dump(
self,
file: Union[str, os.PathLike] = None,
file: Union[str, os.PathLike],
save_models: bool = True,
base_key: str = "",
) -> str:
):
"""Returns and optionally saves YAML formatted string defining the model.
Args:
@@ -307,13 +360,25 @@ def yaml(
Returns:
YAML formatted string defining the model.
"""
file_prefix = ""
if file is not None:
file_prefix = os.path.splitext(file)[0]
config = json.loads(self.json(base_key=base_key, file_prefix=file_prefix, save_models=save_models))
s = yaml.dump({"model_class": self.__class__.__name__} | config,
default_flow_style=None, sort_keys=False)
if file is not None:
with open(file, "w") as f:
f.write(s)
return s
file_prefix = os.path.splitext(file)[0]

with open(file, "w") as f:
f.write(self.yaml(
base_key=base_key,
file_prefix=file_prefix,
save_models=save_models)
)

@classmethod
def from_file(cls, filename: str):
if not os.path.exists(filename):
raise OSError(f"file {filename} not found")

with open(filename, "r") as file:
return cls.from_yaml(file)

@classmethod
def from_yaml(cls, yaml_obj: Union[str, TextIO]):
return cls.model_validate(yaml.safe_load(yaml_obj))
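Taken together, the before-mode validators and the new serialization helpers let a concrete model round-trip through YAML. The sketch below assumes that `ScalarInputVariable` and `ScalarOutputVariable` keep their earlier constructor fields (`default`, `value_range`); the subclass and file name are hypothetical.

```python
# Hypothetical end-to-end usage of the v2 API defined above. Assumes
# ScalarInputVariable/ScalarOutputVariable keep their earlier constructor
# fields (default, value_range); names and file paths are illustrative.
from typing import Any

from lume_model.base import LUMEBaseModel
from lume_model.variables import ScalarInputVariable, ScalarOutputVariable


class ExampleModel(LUMEBaseModel):
    def evaluate(self, input_dict: dict[str, Any]) -> dict[str, Any]:
        return {"y": 2.0 * input_dict["x"]}


model = ExampleModel(
    input_variables=[
        ScalarInputVariable(name="x", default=0.1, value_range=[0.0, 1.0])
    ],
    output_variables=[ScalarOutputVariable(name="y")],
)

model.dump("example_model.yml")  # writes YAML via the new dump()
restored = ExampleModel.from_file("example_model.yml")
print(restored.evaluate({"x": 0.5}))  # {'y': 1.0}
```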


37 changes: 21 additions & 16 deletions lume_model/torch/model.py
@@ -4,7 +4,8 @@
from copy import deepcopy

import torch
from pydantic import validator
import yaml
from pydantic import validator, field_validator
from botorch.models.transforms.input import ReversibleInputTransform

from lume_model.base import LUMEBaseModel
@@ -43,19 +44,21 @@ class TorchModel(LUMEBaseModel):
device: Union[torch.device, str] = "cpu"
fixed_model: bool = True

def __init__(
self,
config: Union[dict, str] = None,
**kwargs,
):
"""Initializes TorchModel.
Args:
config: Model configuration as dictionary, YAML or JSON formatted string or file path. This overrides
all other arguments.
**kwargs: See class attributes.
def __init__(self, *args, **kwargs):
"""
Initialize Xopt.
"""
super().__init__(config, **kwargs)
if len(args) == 1:
if len(kwargs) > 0:
raise ValueError("cannot specify yaml string and kwargs for Xopt init")
super().__init__(**yaml.safe_load(args[0]))
elif len(args) > 1:
raise ValueError(
"arguments to Xopt must be either a single yaml string "
"or a keyword arguments passed directly to pydantic"
)
else:
super().__init__(**kwargs)

# set precision
self.model.to(dtype=self.dtype)
@@ -81,14 +84,16 @@ def dtype(self):
def _tkwargs(self):
return {"device": self.device, "dtype": self.dtype}

@validator("model", pre=True)
@field_validator("model", mode="before")
def validate_torch_model(cls, v):
if isinstance(v, (str, os.PathLike)):
if os.path.exists(v):
v = torch.load(v)
else:
raise ValueError(f"path {v} does not exist!!")
return v

@validator("input_transformers", "output_transformers", pre=True)
@field_validator("input_transformers", "output_transformers", mode="before")
def validate_botorch_transformers(cls, v):
if not isinstance(v, list):
raise ValueError("Transformers must be passed as list.")
@@ -102,7 +107,7 @@ def validate_botorch_transformers(cls, v):
v = loaded_transformers
return v

@validator("output_format")
@field_validator("output_format")
def validate_output_format(cls, v):
supported_formats = ["tensor", "variable", "raw"]
if v not in supported_formats:
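The `model` field of `TorchModel` accepts either an in-memory `torch.nn.Module` or a checkpoint path, and the before-mode validator resolves the path with `torch.load` before field validation runs. A trimmed, hypothetical version of that pattern:

```python
# Hypothetical, trimmed version of the path-or-module pattern used by
# TorchModel above; ModuleHolder is not part of lume-model.
import os

import torch
from pydantic import BaseModel, ConfigDict, field_validator


class ModuleHolder(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    model: torch.nn.Module

    @field_validator("model", mode="before")
    def load_from_path(cls, v):
        # Resolve a checkpoint path to a module before validation.
        if isinstance(v, (str, os.PathLike)):
            if not os.path.exists(v):
                raise ValueError(f"path {v} does not exist")
            v = torch.load(v)  # expects a pickled torch.nn.Module
        return v


holder = ModuleHolder(model=torch.nn.Linear(2, 1))
```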
2 changes: 1 addition & 1 deletion lume_model/utils.py
@@ -51,7 +51,7 @@ def verify_unique_variable_names(variables: Union[list[InputVariable], list[OutputVariable]]):
raise ValueError(f"{var_str} names {non_unique_names} are not unique.")


def serialize_variables(v):
def serialize_variables(v: dict):
"""Performs custom serialization for in- and output variables.
Args:
Expand Down
