diff --git a/luxonis_ml/utils/registry.py b/luxonis_ml/utils/registry.py index 148c5a9e..49d4f9bd 100644 --- a/luxonis_ml/utils/registry.py +++ b/luxonis_ml/utils/registry.py @@ -1,5 +1,15 @@ +import warnings from abc import ABCMeta -from typing import Callable, Dict, Generic, Optional, Tuple, TypeVar, Union +from typing import ( + Callable, + Dict, + Generic, + Optional, + Tuple, + TypeVar, + Union, + overload, +) T = TypeVar("T", bound=type) @@ -9,7 +19,7 @@ def __init__(self, name: str): """A Registry class to store and retrieve modules. @type name: str - @param name: Name of the registry + @ivar name: Name of the registry """ self._module_dict: Dict[str, T] = {} self._name = name @@ -23,14 +33,16 @@ def __repr__(self): def __len__(self): return len(self._module_dict) + def __getitem__(self, key: str) -> T: + return self.get(key) + + def __setitem__(self, key: str, value: T) -> None: + self._register(module=value, module_name=key, force=True) + @property def name(self): return self._name - @property - def module_dict(self): - return self._module_dict - def get(self, key: str) -> T: """Retrieves the registry record for the key. @@ -43,16 +55,57 @@ def get(self, key: str) -> T: """ module_cls = self._module_dict.get(key, None) if module_cls is None: - raise KeyError(f"Class `{key}` not in the `{self.name}` registry.") + raise KeyError(f"'{key}' not in the '{self.name}' registry.") else: return module_cls + @overload + def register_module( + self, name: Optional[str] = ..., module: None = ..., force: bool = ... + ) -> Callable[[T], T]: ... + + @overload + def register_module( + self, name: Optional[str] = ..., module: T = ..., force: bool = ... + ) -> T: ... + def register_module( self, name: Optional[str] = None, module: Optional[T] = None, force: bool = False, ) -> Union[T, Callable[[T], T]]: + warnings.warn( + "`register_module` is deprecated, use `register` instead." + ) + + return self.register(name=name, module=module, force=force) + + @overload + def register( + self, + module: None = ..., + *, + name: Optional[str] = ..., + force: bool = ..., + ) -> Callable[[T], T]: ... + + @overload + def register( + self, + module: T = ..., + *, + name: Optional[str] = ..., + force: bool = ..., + ) -> None: ... + + def register( + self, + module: Optional[T] = None, + *, + name: Optional[str] = None, + force: bool = False, + ) -> Optional[Callable[[T], T]]: """Registers a module. Can be used as a decorator or as a normal method: @@ -86,19 +139,16 @@ def register_module( @raise KeyError: Raised if class name already exists and C{force==False} """ - # use it as a normal method: x.register_module(module=SomeClass) if module is not None: - self._register_module(module=module, module_name=name, force=force) - return module + return self._register(module=module, module_name=name, force=force) - # use it as a decorator: @x.register_module() - def _register(module: T) -> T: - self._register_module(module=module, module_name=name, force=force) + def wrapper(module: T) -> T: + self._register(module=module, module_name=name, force=force) return module - return _register + return wrapper - def _register_module( + def _register( self, module: T, module_name: Optional[str] = None, force: bool = False ) -> None: if module_name is None: @@ -107,7 +157,8 @@ def _register_module( if not force and module_name in self._module_dict: existed_module = self._module_dict[module_name] raise KeyError( - f"`{module_name}` already registred in `{self.name}` registry at `{existed_module.__module__}`." + f"`{module_name}` already registred in `{self.name}` " + f"registry at `{existed_module.__module__}`." ) self._module_dict[module_name] = module diff --git a/tests/test_utils/test_registry.py b/tests/test_utils/test_registry.py index 7be125b4..cbb7d8c4 100644 --- a/tests/test_utils/test_registry.py +++ b/tests/test_utils/test_registry.py @@ -11,40 +11,43 @@ def registry() -> Registry: def test_creation(): registry = Registry("test") assert registry.name == "test" - assert not registry.module_dict + assert not registry._module_dict assert len(registry) == 0 assert str(registry) == "Registry('test')" assert repr(registry) == "Registry('test')" def test_registry(registry: Registry): - assert not registry.module_dict + assert not registry._module_dict assert len(registry) == 0 class A: pass - registry.register_module(module=A) + registry.register(module=A) assert registry.get("A") is A - @registry.register_module() + @registry.register() class B: pass assert registry.get("B") is B assert len(registry) == 2 - registry.register_module(name="C", module=A) + registry.register(name="C", module=A) assert registry.get("C") is A assert len(registry) == 3 - registry.register_module(name="C", module=B, force=True) + registry.register(name="C", module=B, force=True) assert registry.get("C") is B + registry["D"] = A + assert registry["D"] is A + with pytest.raises(KeyError): - registry.register_module(name="C", module=A, force=False) + registry.register(name="C", module=A, force=False) - @registry.register_module(name="Foo") + @registry.register(name="Foo") class Bar: pass