Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context feature #79

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions colt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from colt.builder import ColtBuilder
from colt.callback import ColtCallback, SkipCallback # noqa: F401
from colt.context import ColtContext # noqa: F401
from colt.default_registry import DefaultRegistry
from colt.error import ConfigurationError # noqa: F401
from colt.lazy import Lazy # noqa: F401
Expand Down
50 changes: 36 additions & 14 deletions colt/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class UnionType: ...


from colt.callback import ColtCallback, MultiCallback, SkipCallback
from colt.context import ColtContext
from colt.default_registry import DefaultRegistry
from colt.error import ConfigurationError
from colt.lazy import Lazy
Expand Down Expand Up @@ -97,22 +98,25 @@ def __call__(
config: Any,
cls: Optional[Union[Type[T], Callable[..., T]]] = None,
) -> Union[T, Any]:
context = ColtContext(config=config)
if self._callback is not None:
with suppress(SkipCallback):
config = self._callback.on_start(self, config, cls)
return self._build(config, (), cls)
config = self._callback.on_start(config, self, context, cls)
return self._build(config, (), cls, context=context)

def dry_run(
self,
config: Any,
cls: Optional[Union[Type[T], Callable[..., T]]] = None,
*,
path: ParamPath = (),
context: Optional[ColtContext] = None,
) -> Union[T, Any]:
context = context or ColtContext(config=config)
if self._callback is not None:
with suppress(SkipCallback):
config = self._callback.on_start(self, config, cls)
return self._build(config, path, cls, skip_construction=True)
config = self._callback.on_start(config, self, context, cls)
return self._build(config, path, cls, context=context, skip_construction=True)

@staticmethod
def _get_constructor_by_name(
Expand Down Expand Up @@ -152,11 +156,6 @@ def _is_namedtuple(cls: Any) -> bool:
return False
return all(type(name) is str for name in fields)

@staticmethod
def _catname(parent: str, *keys: Union[int, str]) -> str:
key = ".".join(str(x) for x in keys)
return f"{parent}.{key}" if parent else key

def _get_constructor(
self,
config: Any,
Expand All @@ -182,6 +181,8 @@ def _construct_args(
constructor: Callable[..., T],
config: Mapping[str, Any],
path: ParamPath,
*,
context: ColtContext,
skip_construction: bool = False,
) -> Tuple[List[Any], Dict[str, Any]]:
if not config:
Expand All @@ -200,6 +201,7 @@ def _construct_args(
self._build(
val,
path + (self._argskey, i),
context=context,
skip_construction=skip_construction,
)
for i, val in enumerate(args_config)
Expand All @@ -223,6 +225,7 @@ def _construct_args(
val,
path + (key,),
type_hints.get(key),
context=context,
skip_construction=skip_construction,
)
for key, val in config.items()
Expand All @@ -236,12 +239,15 @@ def _build(
path: ParamPath,
annotation: Optional[Union[Type[T], Callable[..., T], Any]] = None,
*,
context: ColtContext,
raise_configuration_error: bool = True,
skip_construction: bool = False,
) -> Union[T, Any]:
if self._callback is not None:
with suppress(SkipCallback):
config = self._callback.on_build(self, config, path, annotation)
config = self._callback.on_build(
path, config, self, context, annotation
)

if annotation is not None and isinstance(annotation, type):
annotation = remove_optional(annotation)
Expand Down Expand Up @@ -278,6 +284,7 @@ def _build(
x,
path + (i,),
value_cls,
context=context,
skip_construction=skip_construction,
)
for i, x in enumerate(config)
Expand All @@ -290,6 +297,7 @@ def _build(
x,
path + (i,),
value_cls,
context=context,
skip_construction=skip_construction,
)
for i, x in enumerate(config)
Expand All @@ -301,6 +309,7 @@ def _build(
self._build(
x,
path + (i,),
context=context,
skip_construction=skip_construction,
)
for i, x in enumerate(config)
Expand All @@ -312,6 +321,7 @@ def _build(
x,
path + (i,),
args[0],
context=context,
skip_construction=skip_construction,
)
for i, x in enumerate(config)
Expand All @@ -324,7 +334,7 @@ def _build(
)

return tuple(
self._build(value_config, path + (i,), value_cls)
self._build(value_config, path + (i,), value_cls, context=context)
for i, (value_config, value_cls) in enumerate(zip(config, args))
)

Expand All @@ -336,11 +346,13 @@ def _build(
key_config,
path + (f"[key:{i}]",),
key_cls,
context=context,
skip_construction=skip_construction,
): self._build(
value_config,
path + (key_config,),
value_cls,
context=context,
skip_construction=skip_construction,
)
for i, (key_config, value_config) in enumerate(config.items())
Expand All @@ -360,6 +372,7 @@ def _build(
value_config,
path + (key,),
type_hints.get(key),
context=context,
skip_construction=skip_construction,
)
for key, value_config in config.items()
Expand All @@ -370,7 +383,9 @@ def _build(

if origin in (Union, UnionType):
if not args:
return self._build(config, path, skip_construction=skip_construction)
return self._build(
config, path, context=context, skip_construction=skip_construction
)

trial_exceptions: List[Tuple[Any, Exception, str]] = []
for value_cls in args:
Expand All @@ -379,6 +394,7 @@ def _build(
config,
path,
value_cls,
context=context,
raise_configuration_error=False,
skip_construction=skip_construction,
)
Expand All @@ -401,7 +417,7 @@ def _build(

if origin == Lazy:
value_cls = args[0] if args else None
return Lazy(config, path, value_cls, self)
return Lazy(config, path, context, value_cls, self)

if isinstance(config, (list, set, tuple)):
if origin is not None and not isinstance(config, origin):
Expand All @@ -421,6 +437,7 @@ def _build(
x,
path + (i,),
value_cls,
context=context,
skip_construction=skip_construction,
)
for i, x in enumerate(config)
Expand All @@ -444,6 +461,7 @@ def _build(
key: self._build(
val,
path + (key,),
context=context,
skip_construction=skip_construction,
)
for key, val in config.items()
Expand Down Expand Up @@ -472,7 +490,11 @@ def _build(
)

args_for_constructor, kwargs_for_constructor = self._construct_args(
constructor, config, path, skip_construction=skip_construction
constructor,
config,
path,
context=context,
skip_construction=skip_construction,
)

if skip_construction:
Expand Down
27 changes: 16 additions & 11 deletions colt/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

if typing.TYPE_CHECKING:
from colt.builder import ColtBuilder, ParamPath
from colt.context import ColtContext


T = TypeVar("T")
Expand All @@ -15,18 +16,20 @@ class SkipCallback(Exception): ...
class ColtCallback:
def on_start(
self,
builder: "ColtBuilder",
config: Any,
cls: Optional[Union[Type[T], Callable[..., T]]] = None,
builder: "ColtBuilder",
context: "ColtContext",
annotation: Optional[Union[Type[T], Callable[..., T]]] = None,
) -> Any:
del builder, cls
del builder, annotation
return config

def on_build(
self,
builder: "ColtBuilder",
config: Any,
path: "ParamPath",
config: Any,
builder: "ColtBuilder",
context: "ColtContext",
annotation: Optional[Union[Type[T], Callable[..., T]]] = None,
) -> Any:
raise SkipCallback
Expand All @@ -38,23 +41,25 @@ def __init__(self, *callbacks: ColtCallback) -> None:

def on_start(
self,
builder: "ColtBuilder",
config: Any,
cls: Optional[Union[Type[T], Callable[..., T]]] = None,
builder: "ColtBuilder",
context: "ColtContext",
annotation: Optional[Union[Type[T], Callable[..., T]]] = None,
) -> Any:
for callback in self.callbacks:
with suppress(SkipCallback):
config = callback.on_start(builder, config, cls)
config = callback.on_start(config, builder, context, annotation)
return config

def on_build(
self,
builder: "ColtBuilder",
config: Any,
path: "ParamPath",
config: Any,
builder: "ColtBuilder",
context: "ColtContext",
annotation: Optional[Union[Type[T], Callable[..., T]]] = None,
) -> Any:
for callback in self.callbacks:
with suppress(SkipCallback):
config = callback.on_build(builder, config, path, annotation)
config = callback.on_build(path, config, builder, context, annotation)
return config
15 changes: 15 additions & 0 deletions colt/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import dataclasses
import typing
from typing import Any, Dict, Type

if typing.TYPE_CHECKING:
from colt.callback import ColtCallback


CallbackState = Dict[Type["ColtCallback"], Dict[str, Any]]


@dataclasses.dataclass
class ColtContext:
config: Any
state: Dict[str, Any] = dataclasses.field(default_factory=dict)
21 changes: 18 additions & 3 deletions colt/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

if typing.TYPE_CHECKING:
from colt.builder import ColtBuilder, ParamPath
from colt.context import ColtContext

T = TypeVar("T")

Expand All @@ -24,7 +25,8 @@ class Lazy(Generic[T]):
def __init__(
self,
config: Any,
path: "ParamPath" = (),
path: "ParamPath",
context: "ColtContext",
cls: Optional[Type[T]] = None,
builder: Optional["ColtBuilder"] = None,
) -> None:
Expand All @@ -33,14 +35,25 @@ def __init__(
self._cls = cls
self._config = config or {}
self._path = path
self._context = context
self._builder = builder or ColtBuilder()

self._builder.dry_run(self._config, self._cls, path=self._path)
self._builder.dry_run(
self._config, self._cls, path=self._path, context=self._context
)

@property
def config(self) -> Any:
return self._config

@property
def path(self) -> "ParamPath":
return self._path

@property
def builder(self) -> "ColtBuilder":
return self._builder

@property
def constructor(self) -> Optional[Union[Type[T], Callable[..., T]]]:
return (
Expand Down Expand Up @@ -74,4 +87,6 @@ def construct(
update_field(config, k, v)
else:
config = self._config
return self._builder._build(config, self._path, self._cls)
return self._builder._build(
config, self._path, self._cls, context=self._context
)
17 changes: 10 additions & 7 deletions tests/test_callback.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import dataclasses
from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Optional, Type, TypeVar, Union

from colt import ColtBuilder, ColtCallback, SkipCallback
from colt import ColtBuilder, ColtCallback, ColtContext, SkipCallback
from colt.types import ParamPath

T = TypeVar("T")

Expand All @@ -15,11 +16,12 @@ class Foo:
class AddName(ColtCallback):
def on_start(
self,
builder: "ColtBuilder",
config: Any,
cls: Optional[Union[Type[T], Callable[..., T]]] = None,
builder: "ColtBuilder",
context: "ColtContext",
annotation: Optional[Union[Type[T], Callable[..., T]]] = None,
) -> Any:
if cls is Foo and isinstance(config, dict) and "name" not in config:
if annotation is Foo and isinstance(config, dict) and "name" not in config:
return {**config, "name": "foo"}
raise SkipCallback

Expand All @@ -42,9 +44,10 @@ class Foo:
class AddOne(ColtCallback):
def on_build(
self,
builder: ColtBuilder,
path: "ParamPath",
config: Any,
path: Tuple[Union[int, str], ...],
builder: "ColtBuilder",
context: "ColtContext",
annotation: Optional[Union[Type[T], Callable[..., T]]] = None,
) -> Any:
del builder, path
Expand Down
Loading