Skip to content

Commit

Permalink
Add allow_fail keyword to ModuleType
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Apr 21, 2024
1 parent ac4c1c7 commit 3b6afc4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
13 changes: 10 additions & 3 deletions plum/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,35 @@ class ModuleType(ResolvableType):
Args:
module (str): Module that the type lives in.
name (str): Name of the type that is promised.
allow_fail (bool, optional): If the type is does not exist in `module`,
do not raise an `AttributeError`.
"""

def __init__(self, module, name):
def __init__(self, module, name, allow_fail=False):
if module in {"__builtin__", "__builtins__"}:
module = "builtins"
ResolvableType.__init__(self, f"ModuleType[{module}.{name}]")
self._name = name
self._module = module
self._allow_fail = allow_fail

def __new__(cls, module, name):
def __new__(cls, module, name, allow_fail=False):
return ResolvableType.__new__(cls, f"ModuleType[{module}.{name}]")

def retrieve(self):
"""Attempt to retrieve the type from the reference module.
Returns:
:class:`ModuleType`: `self`.
bool: Whether the retrieval succeeded.
"""
if self._type is None:
if self._module in sys.modules:
type = sys.modules[self._module]
for name in self._name.split("."):
# If `type` does not contain `name` and `self._allow_fail` is
# set, then silently fail.
if not hasattr(type, name) and self._allow_fail:
return False
type = getattr(type, name)
self.deliver(type)
return self._type is not None
Expand Down
10 changes: 10 additions & 0 deletions tests/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def test_moduletype(module, name, type):
assert not t.retrieve()


def test_moduletype_allow_fail():
t_not_allowed = ModuleType("__builtin__", "nonexisting")
t_allowed = ModuleType("__builtin__", "nonexisting", allow_fail=True)

with pytest.raises(AttributeError):
t_not_allowed.retrieve()

assert not t_allowed.retrieve()


def test_is_hint():
assert not _is_hint(int)
assert _is_hint(typing.Union[int, float])
Expand Down

0 comments on commit 3b6afc4

Please sign in to comment.