Skip to content

Commit

Permalink
renamed
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Jan 16, 2025
1 parent 3715991 commit 5939456
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 27 deletions.
77 changes: 50 additions & 27 deletions luxonis_ml/utils/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from abc import ABCMeta
from functools import wraps
from typing import (
Callable,
Dict,
Expand All @@ -19,7 +19,7 @@ def __init__(self, name: str):
"""A Registry class to store and retrieve modules.
@type name: str
@param name: Name of the registry
@ivar name: Name of the registry
"""
self._module_dict: Dict[str, T] = {}
self._name = name
Expand All @@ -33,14 +33,16 @@ def __repr__(self):
def __len__(self):
return len(self._module_dict)

def __getitem__(self, key: str) -> T:
return self.get(key)

def __setitem__(self, key: str, value: T) -> None:
self._register(module=value, module_name=key, force=True)

@property
def name(self):
return self._name

@property
def module_dict(self):
return self._module_dict

def get(self, key: str) -> T:
"""Retrieves the registry record for the key.
Expand All @@ -53,32 +55,57 @@ def get(self, key: str) -> T:
"""
module_cls = self._module_dict.get(key, None)
if module_cls is None:
raise KeyError(f"Class `{key}` not in the `{self.name}` registry.")
raise KeyError(f"'{key}' not in the '{self.name}' registry.")
else:
return module_cls

@overload
def register_module(
self, name: Optional[str] = ..., module: None = ..., force: bool = ...
) -> Callable[[T], T]: ...

@overload
def register_module(
self, name: Optional[str] = ..., module: T = ..., force: bool = ...
) -> T: ...

def register_module(
self,
name: Optional[str] = ...,
name: Optional[str] = None,
module: Optional[T] = None,
force: bool = False,
) -> Union[T, Callable[[T], T]]:
warnings.warn(
"`register_module` is deprecated, use `register` instead."
)

return self.register(name=name, module=module, force=force)

@overload
def register(
self,
module: None = ...,
*,
name: Optional[str] = ...,
force: bool = ...,
) -> Callable[[T], T]: ...

@overload
def register_module(
def register(
self,
name: Optional[str] = ...,
module: T = ...,
*,
name: Optional[str] = ...,
force: bool = ...,
) -> T: ...
) -> None: ...

def register_module(
def register(
self,
name: Optional[str] = None,
module: Optional[T] = None,
*,
name: Optional[str] = None,
force: bool = False,
) -> Union[T, Callable[[T], T]]:
) -> Optional[Callable[[T], T]]:
"""Registers a module.
Can be used as a decorator or as a normal method:
Expand Down Expand Up @@ -113,32 +140,28 @@ def register_module(
"""

if module is not None:
return self._register_module(
module=module, module_name=name, force=force
)
return self._register(module=module, module_name=name, force=force)

@wraps
def _register(module: T) -> T:
return self._register_module(
module=module, module_name=name, force=force
)
def wrapper(module: T) -> T:
self._register(module=module, module_name=name, force=force)
return module

return _register # type: ignore
return wrapper

def _register_module(
def _register(
self, module: T, module_name: Optional[str] = None, force: bool = False
) -> T:
) -> None:
if module_name is None:
module_name = module.__name__

if not force and module_name in self._module_dict:
existed_module = self._module_dict[module_name]
raise KeyError(
f"`{module_name}` already registred in `{self.name}` registry at `{existed_module.__module__}`."
f"`{module_name}` already registred in `{self.name}` "
f"registry at `{existed_module.__module__}`."
)

self._module_dict[module_name] = module
return module


class AutoRegisterMeta(ABCMeta):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_utils/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class B:
registry.register(name="C", module=B, force=True)
assert registry.get("C") is B

registry["D"] = A
assert registry["D"] is A

with pytest.raises(KeyError):
registry.register(name="C", module=A, force=False)

Expand Down

0 comments on commit 5939456

Please sign in to comment.