Skip to content

Commit

Permalink
Break ties for equal signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed May 24, 2024
1 parent bb657d7 commit 03c4276
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 21 deletions.
54 changes: 34 additions & 20 deletions plum/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,6 @@ def expand_varargs(self, n: int) -> Tuple[TypeHint, ...]:
return self.types

def __le__(self, other) -> bool:
# If this signature has variable arguments, but the other does not, then this
# signature cannot be possibly smaller.
if self.has_varargs and not other.has_varargs:
return False

# If this signature and the other signature both have variable arguments, then
# the variable type of this signature must be less than the variable type of the
# other signature.
if (
self.has_varargs
and other.has_varargs
and not (
beartype.door.TypeHint(self.varargs)
<= beartype.door.TypeHint(other.varargs)
)
):
return False

# If the number of types of the signatures are unequal, then the signature
# with the fewer number of types must be expanded using variable arguments.
if not (
Expand All @@ -179,12 +161,44 @@ def __le__(self, other) -> bool:
# Finally, expand the types and compare.
self_types = self.expand_varargs(len(other.types))
other_types = other.expand_varargs(len(self.types))
return all(
if all(
[
beartype.door.TypeHint(x) == beartype.door.TypeHint(y)
for x, y in zip(self_types, other_types)
]
):
# If both have variable arguments, implement the subset relationship.
if self.has_varargs and other.has_varargs:
self_varargs = beartype.door.TypeHint(self.varargs)
other_varargs = beartype.door.TypeHint(other.varargs)
return self_varargs <= other_varargs

# Having varargs makes you slightly larger.
elif self.has_varargs:
return False
elif other.has_varargs:
return True

else:
return True

elif all(
[
beartype.door.TypeHint(x) <= beartype.door.TypeHint(y)
for x, y in zip(self_types, other_types)
]
)
):
# If both have variable arguments, implement the subset relationship.
if self.has_varargs and other.has_varargs:
self_varargs = beartype.door.TypeHint(self.varargs)
other_varargs = beartype.door.TypeHint(other.varargs)
return self_varargs <= other_varargs

else:
return True

else:
return False

def match(self, values) -> bool:
"""Check whether values match the signature.
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ source = ["plum"]
testpaths = ["tests/", "plum", "docs"]
addopts = [
"-ra",
"-p no:doctest",
"-p",
"no:doctest",
]
minversion = "6.0"

Expand Down
29 changes: 29 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,35 @@ def test_expand_varargs():
assert s.expand_varargs(4) == (int, int, float, float)


def test_varargs_tie_breaking():
assert Sig(int) < Sig(int, varargs=int)
assert not Sig(int) >= Sig(int, varargs=int)
assert Sig(int, varargs=int) < Sig(int, Num)
assert not Sig(int, varargs=int) >= Sig(int, Num)
assert Sig(int, int, varargs=int) < Sig(int, Num)
assert not Sig(int, int, varargs=int) >= Sig(int, Num)


def test_varargs_subset():
assert Sig(int, varargs=int) <= Sig(int, varargs=int)
assert Sig(int, varargs=int) <= Sig(Num, varargs=int)
assert Sig(int, varargs=int) <= Sig(int, varargs=Num)
assert Sig(int, varargs=int) <= Sig(Num, varargs=Num)
assert Sig(int, varargs=Num) <= Sig(int, varargs=Num)
assert Sig(int, varargs=Num) <= Sig(Num, varargs=Num)
assert Sig(Num, varargs=int) <= Sig(Num, varargs=int)
assert Sig(Num, varargs=int) <= Sig(Num, varargs=Num)
assert Sig(Num, varargs=Num) <= Sig(Num, varargs=Num)

assert not Sig(Num, varargs=int) <= Sig(int, varargs=int)
assert not Sig(int, varargs=Num) <= Sig(int, varargs=int)
assert not Sig(Num, varargs=Num) <= Sig(int, varargs=int)
assert not Sig(int, varargs=Num) <= Sig(Num, varargs=int)
assert not Sig(Num, varargs=Num) <= Sig(Num, varargs=int)
assert not Sig(Num, varargs=int) <= Sig(int, varargs=Num)
assert not Sig(Num, varargs=Num) <= Sig(int, varargs=Num)


def test_comparison():
# Variable arguments shortcuts:
assert not Sig(varargs=int) <= Sig()
Expand Down

0 comments on commit 03c4276

Please sign in to comment.