Skip to content

Commit

Permalink
Merge pull request #70 from altescy/lazy-update
Browse files Browse the repository at this point in the history
Make lazy updatable
  • Loading branch information
altescy authored Sep 19, 2024
2 parents ca45323 + dbe04e9 commit 4f0fb17
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 6 deletions.
41 changes: 38 additions & 3 deletions colt/lazy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import typing
from typing import Any, Generic, Optional, Type, TypeVar
from copy import deepcopy
from typing import Any, Generic, Mapping, Optional, Sequence, Type, TypeVar, Union

from colt.utils import update_field

if typing.TYPE_CHECKING:
from colt.builder import ColtBuilder
Expand All @@ -24,6 +27,38 @@ def __init__(

self._builder.dry_run(self._config, self._cls, param_name=self._param_name)

def construct(self, **kwargs: Any) -> T:
config = {**self._config, **kwargs}
@property
def config(self) -> Any:
return self._config

@property
def constructor(self) -> Optional[Type[T]]:
return self._cls

def update(
self,
*args: Mapping[Union[int, str, Sequence[Union[int, str]]], Any],
**kwargs: Any,
) -> None:
for arg in args:
for field, value in arg.items():
update_field(self._config, field, value)
for k, v in kwargs.items():
update_field(self._config, k, v)
self._builder.dry_run(self._config, self._cls, param_name=self._param_name)

def construct(
self,
*args: Mapping[Union[int, str, Sequence[Union[int, str]]], Any],
**kwargs: Any,
) -> T:
if args or kwargs:
config = deepcopy(self._config)
for arg in args:
for field, value in arg.items():
update_field(config, field, value)
for k, v in kwargs.items():
update_field(config, k, v)
else:
config = self._config
return self._builder._build(config, self._param_name, self._cls)
39 changes: 37 additions & 2 deletions colt/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib
import pkgutil
import sys
import typing as tp
from typing import Any, Dict, List, Sequence, Union


def import_submodules(package_name: str) -> None:
Expand All @@ -25,7 +25,7 @@ def import_submodules(package_name: str) -> None:
import_submodules(subpackage)


def import_modules(module_names: tp.List[str]) -> None:
def import_modules(module_names: List[str]) -> None:
"""
This method import modules recursively.
You should call this method to register your classes
Expand All @@ -38,3 +38,38 @@ def import_modules(module_names: tp.List[str]) -> None:
def indent(s: str, level: int = 1) -> str:
tabs = "\t" * level
return tabs + s.replace("\n", f"\n{tabs}")


def update_field(
obj: Union[Dict[Union[int, str], Any], List[Any]],
field: Union[int, str, Sequence[Union[int, str]]],
value: Any,
) -> None:
path: Sequence[Union[int, str]]
if isinstance(field, str):
path = field.split(".")
elif isinstance(field, int):
path = (field,)
else:
path = field
if len(path) == 1:
target_field = path[0]
if isinstance(obj, dict):
obj[target_field] = value
elif isinstance(obj, list):
if target_field == "+":
obj.append(value)
else:
target_field = int(target_field)
obj[target_field] = value
else:
raise ValueError("obj must be dict or list")
else:
target_field = path[0]
if isinstance(obj, dict):
update_field(obj[target_field], path[1:], value)
elif isinstance(obj, list):
target_field = int(target_field)
update_field(obj[target_field], path[1:], value)
else:
raise ValueError("obj must be dict or list")
28 changes: 27 additions & 1 deletion tests/test_lazy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses

import pytest

import colt
from colt import Lazy
from colt import ConfigurationError, Lazy


def test_lazy() -> None:
Expand All @@ -23,3 +25,27 @@ class Bar:
assert isinstance(foo, Foo)
assert foo.x == "hello"
assert foo.y == 10


def test_lazy_update() -> None:
@dataclasses.dataclass
class Foo:
name: str

@dataclasses.dataclass
class Bar:
foo: Lazy[Foo]

bar = colt.build({"foo": {"name": "foo"}}, Bar)

assert isinstance(bar, Bar)
assert isinstance(bar.foo, Lazy)

bar.foo.update(name="bar")
assert bar.foo.config == {"name": "bar"}

bar.foo.update({"name": "baz"})
assert bar.foo.config == {"name": "baz"}

with pytest.raises(ConfigurationError):
bar.foo.update(name=123)
28 changes: 28 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Dict, List, Union

import pytest

from colt.utils import update_field


@pytest.mark.parametrize(
"obj, field, value, expected",
[
({"a": 1}, "a", 2, {"a": 2}),
({0: 1}, 0, 2, {0: 2}),
({"a": {"b": 1}}, "a.b", 2, {"a": {"b": 2}}),
({"a": [1]}, "a.0", 2, {"a": [2]}),
({"a": 1}, "b", 2, {"a": 1, "b": 2}),
({"a": [1]}, "a.+", 2, {"a": [1, 2]}),
({"a": [1, {"b": 1}]}, "a.1.b", 2, {"a": [1, {"b": 2}]}),
({"a": {"b": [1]}}, ("a", "b", 0), 2, {"a": {"b": [2]}}),
],
)
def test_update_field(
obj: Union[Dict[Union[int, str], Any], List[Any]],
field: Union[int, str, List[Union[int, str]]],
value: Any,
expected: Union[Dict, List],
) -> None:
update_field(obj, field, value)
assert obj == expected

0 comments on commit 4f0fb17

Please sign in to comment.