Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Registry Type Fix #225

Merged
merged 3 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 67 additions & 16 deletions luxonis_ml/utils/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import warnings
from abc import ABCMeta
from typing import Callable, Dict, Generic, Optional, Tuple, TypeVar, Union
from typing import (
Callable,
Dict,
Generic,
Optional,
Tuple,
TypeVar,
Union,
overload,
)

T = TypeVar("T", bound=type)

Expand All @@ -9,7 +19,7 @@ def __init__(self, name: str):
"""A Registry class to store and retrieve modules.

@type name: str
@param name: Name of the registry
@ivar name: Name of the registry
"""
self._module_dict: Dict[str, T] = {}
self._name = name
Expand All @@ -23,14 +33,16 @@ def __repr__(self):
def __len__(self):
return len(self._module_dict)

def __getitem__(self, key: str) -> T:
return self.get(key)

def __setitem__(self, key: str, value: T) -> None:
self._register(module=value, module_name=key, force=True)

@property
def name(self):
return self._name

@property
def module_dict(self):
return self._module_dict

def get(self, key: str) -> T:
"""Retrieves the registry record for the key.

Expand All @@ -43,16 +55,57 @@ def get(self, key: str) -> T:
"""
module_cls = self._module_dict.get(key, None)
if module_cls is None:
raise KeyError(f"Class `{key}` not in the `{self.name}` registry.")
raise KeyError(f"'{key}' not in the '{self.name}' registry.")
else:
return module_cls

@overload
def register_module(
self, name: Optional[str] = ..., module: None = ..., force: bool = ...
) -> Callable[[T], T]: ...

@overload
def register_module(
self, name: Optional[str] = ..., module: T = ..., force: bool = ...
) -> T: ...

def register_module(
self,
name: Optional[str] = None,
module: Optional[T] = None,
force: bool = False,
) -> Union[T, Callable[[T], T]]:
warnings.warn(
"`register_module` is deprecated, use `register` instead."
)

return self.register(name=name, module=module, force=force)

@overload
def register(
self,
module: None = ...,
*,
name: Optional[str] = ...,
force: bool = ...,
) -> Callable[[T], T]: ...

@overload
def register(
self,
module: T = ...,
*,
name: Optional[str] = ...,
force: bool = ...,
) -> None: ...

def register(
self,
module: Optional[T] = None,
*,
name: Optional[str] = None,
force: bool = False,
) -> Optional[Callable[[T], T]]:
"""Registers a module.

Can be used as a decorator or as a normal method:
Expand Down Expand Up @@ -86,19 +139,16 @@ def register_module(
@raise KeyError: Raised if class name already exists and C{force==False}
"""

# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(module=module, module_name=name, force=force)
return module
return self._register(module=module, module_name=name, force=force)

# use it as a decorator: @x.register_module()
def _register(module: T) -> T:
self._register_module(module=module, module_name=name, force=force)
def wrapper(module: T) -> T:
self._register(module=module, module_name=name, force=force)
return module

return _register
return wrapper

def _register_module(
def _register(
self, module: T, module_name: Optional[str] = None, force: bool = False
) -> None:
if module_name is None:
Expand All @@ -107,7 +157,8 @@ def _register_module(
if not force and module_name in self._module_dict:
existed_module = self._module_dict[module_name]
raise KeyError(
f"`{module_name}` already registred in `{self.name}` registry at `{existed_module.__module__}`."
f"`{module_name}` already registred in `{self.name}` "
f"registry at `{existed_module.__module__}`."
)

self._module_dict[module_name] = module
Expand Down
19 changes: 11 additions & 8 deletions tests/test_utils/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,43 @@ def registry() -> Registry:
def test_creation():
registry = Registry("test")
assert registry.name == "test"
assert not registry.module_dict
assert not registry._module_dict
assert len(registry) == 0
assert str(registry) == "Registry('test')"
assert repr(registry) == "Registry('test')"


def test_registry(registry: Registry):
assert not registry.module_dict
assert not registry._module_dict
assert len(registry) == 0

class A:
pass

registry.register_module(module=A)
registry.register(module=A)
assert registry.get("A") is A

@registry.register_module()
@registry.register()
class B:
pass

assert registry.get("B") is B
assert len(registry) == 2

registry.register_module(name="C", module=A)
registry.register(name="C", module=A)
assert registry.get("C") is A
assert len(registry) == 3

registry.register_module(name="C", module=B, force=True)
registry.register(name="C", module=B, force=True)
assert registry.get("C") is B

registry["D"] = A
assert registry["D"] is A

with pytest.raises(KeyError):
registry.register_module(name="C", module=A, force=False)
registry.register(name="C", module=A, force=False)

@registry.register_module(name="Foo")
@registry.register(name="Foo")
class Bar:
pass

Expand Down
Loading