Skip to content

Commit

Permalink
Registry Type Fix (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 authored Jan 16, 2025
1 parent c08bd5b commit 466baab
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 24 deletions.
83 changes: 67 additions & 16 deletions luxonis_ml/utils/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import warnings
from abc import ABCMeta
from typing import Callable, Dict, Generic, Optional, Tuple, TypeVar, Union
from typing import (
Callable,
Dict,
Generic,
Optional,
Tuple,
TypeVar,
Union,
overload,
)

T = TypeVar("T", bound=type)

Expand All @@ -9,7 +19,7 @@ def __init__(self, name: str):
"""A Registry class to store and retrieve modules.
@type name: str
@param name: Name of the registry
@ivar name: Name of the registry
"""
self._module_dict: Dict[str, T] = {}
self._name = name
Expand All @@ -23,14 +33,16 @@ def __repr__(self):
def __len__(self):
return len(self._module_dict)

def __getitem__(self, key: str) -> T:
return self.get(key)

def __setitem__(self, key: str, value: T) -> None:
self._register(module=value, module_name=key, force=True)

@property
def name(self):
return self._name

@property
def module_dict(self):
return self._module_dict

def get(self, key: str) -> T:
"""Retrieves the registry record for the key.
Expand All @@ -43,16 +55,57 @@ def get(self, key: str) -> T:
"""
module_cls = self._module_dict.get(key, None)
if module_cls is None:
raise KeyError(f"Class `{key}` not in the `{self.name}` registry.")
raise KeyError(f"'{key}' not in the '{self.name}' registry.")
else:
return module_cls

@overload
def register_module(
self, name: Optional[str] = ..., module: None = ..., force: bool = ...
) -> Callable[[T], T]: ...

@overload
def register_module(
self, name: Optional[str] = ..., module: T = ..., force: bool = ...
) -> T: ...

def register_module(
self,
name: Optional[str] = None,
module: Optional[T] = None,
force: bool = False,
) -> Union[T, Callable[[T], T]]:
warnings.warn(
"`register_module` is deprecated, use `register` instead."
)

return self.register(name=name, module=module, force=force)

@overload
def register(
self,
module: None = ...,
*,
name: Optional[str] = ...,
force: bool = ...,
) -> Callable[[T], T]: ...

@overload
def register(
self,
module: T = ...,
*,
name: Optional[str] = ...,
force: bool = ...,
) -> None: ...

def register(
self,
module: Optional[T] = None,
*,
name: Optional[str] = None,
force: bool = False,
) -> Optional[Callable[[T], T]]:
"""Registers a module.
Can be used as a decorator or as a normal method:
Expand Down Expand Up @@ -86,19 +139,16 @@ def register_module(
@raise KeyError: Raised if class name already exists and C{force==False}
"""

# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(module=module, module_name=name, force=force)
return module
return self._register(module=module, module_name=name, force=force)

# use it as a decorator: @x.register_module()
def _register(module: T) -> T:
self._register_module(module=module, module_name=name, force=force)
def wrapper(module: T) -> T:
self._register(module=module, module_name=name, force=force)
return module

return _register
return wrapper

def _register_module(
def _register(
self, module: T, module_name: Optional[str] = None, force: bool = False
) -> None:
if module_name is None:
Expand All @@ -107,7 +157,8 @@ def _register_module(
if not force and module_name in self._module_dict:
existed_module = self._module_dict[module_name]
raise KeyError(
f"`{module_name}` already registred in `{self.name}` registry at `{existed_module.__module__}`."
f"`{module_name}` already registred in `{self.name}` "
f"registry at `{existed_module.__module__}`."
)

self._module_dict[module_name] = module
Expand Down
19 changes: 11 additions & 8 deletions tests/test_utils/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,43 @@ def registry() -> Registry:
def test_creation():
registry = Registry("test")
assert registry.name == "test"
assert not registry.module_dict
assert not registry._module_dict
assert len(registry) == 0
assert str(registry) == "Registry('test')"
assert repr(registry) == "Registry('test')"


def test_registry(registry: Registry):
assert not registry.module_dict
assert not registry._module_dict
assert len(registry) == 0

class A:
pass

registry.register_module(module=A)
registry.register(module=A)
assert registry.get("A") is A

@registry.register_module()
@registry.register()
class B:
pass

assert registry.get("B") is B
assert len(registry) == 2

registry.register_module(name="C", module=A)
registry.register(name="C", module=A)
assert registry.get("C") is A
assert len(registry) == 3

registry.register_module(name="C", module=B, force=True)
registry.register(name="C", module=B, force=True)
assert registry.get("C") is B

registry["D"] = A
assert registry["D"] is A

with pytest.raises(KeyError):
registry.register_module(name="C", module=A, force=False)
registry.register(name="C", module=A, force=False)

@registry.register_module(name="Foo")
@registry.register(name="Foo")
class Bar:
pass

Expand Down

0 comments on commit 466baab

Please sign in to comment.