diff --git a/plum/signature.py b/plum/signature.py index 83548fa..d71594a 100644 --- a/plum/signature.py +++ b/plum/signature.py @@ -149,24 +149,6 @@ def expand_varargs(self, n: int) -> Tuple[TypeHint, ...]: return self.types def __le__(self, other) -> bool: - # If this signature has variable arguments, but the other does not, then this - # signature cannot be possibly smaller. - if self.has_varargs and not other.has_varargs: - return False - - # If this signature and the other signature both have variable arguments, then - # the variable type of this signature must be less than the variable type of the - # other signature. - if ( - self.has_varargs - and other.has_varargs - and not ( - beartype.door.TypeHint(self.varargs) - <= beartype.door.TypeHint(other.varargs) - ) - ): - return False - # If the number of types of the signatures are unequal, then the signature # with the fewer number of types must be expanded using variable arguments. if not ( @@ -179,12 +161,44 @@ def __le__(self, other) -> bool: # Finally, expand the types and compare. self_types = self.expand_varargs(len(other.types)) other_types = other.expand_varargs(len(self.types)) - return all( + if all( + [ + beartype.door.TypeHint(x) == beartype.door.TypeHint(y) + for x, y in zip(self_types, other_types) + ] + ): + # If both have variable arguments, implement the subset relationship. + if self.has_varargs and other.has_varargs: + self_varargs = beartype.door.TypeHint(self.varargs) + other_varargs = beartype.door.TypeHint(other.varargs) + return self_varargs <= other_varargs + + # Having varargs makes you slightly larger. + elif self.has_varargs: + return False + elif other.has_varargs: + return True + + else: + return True + + elif all( [ beartype.door.TypeHint(x) <= beartype.door.TypeHint(y) for x, y in zip(self_types, other_types) ] - ) + ): + # If both have variable arguments, implement the subset relationship. + if self.has_varargs and other.has_varargs: + self_varargs = beartype.door.TypeHint(self.varargs) + other_varargs = beartype.door.TypeHint(other.varargs) + return self_varargs <= other_varargs + + else: + return True + + else: + return False def match(self, values) -> bool: """Check whether values match the signature. diff --git a/pyproject.toml b/pyproject.toml index 95e9f92..b22f4f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,8 @@ source = ["plum"] testpaths = ["tests/", "plum", "docs"] addopts = [ "-ra", - "-p no:doctest", + "-p", + "no:doctest", ] minversion = "6.0" diff --git a/tests/test_signature.py b/tests/test_signature.py index 57f0729..c7f54e2 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -114,6 +114,35 @@ def test_expand_varargs(): assert s.expand_varargs(4) == (int, int, float, float) +def test_varargs_tie_breaking(): + assert Sig(int) < Sig(int, varargs=int) + assert not Sig(int) >= Sig(int, varargs=int) + assert Sig(int, varargs=int) < Sig(int, Num) + assert not Sig(int, varargs=int) >= Sig(int, Num) + assert Sig(int, int, varargs=int) < Sig(int, Num) + assert not Sig(int, int, varargs=int) >= Sig(int, Num) + + +def test_varargs_subset(): + assert Sig(int, varargs=int) <= Sig(int, varargs=int) + assert Sig(int, varargs=int) <= Sig(Num, varargs=int) + assert Sig(int, varargs=int) <= Sig(int, varargs=Num) + assert Sig(int, varargs=int) <= Sig(Num, varargs=Num) + assert Sig(int, varargs=Num) <= Sig(int, varargs=Num) + assert Sig(int, varargs=Num) <= Sig(Num, varargs=Num) + assert Sig(Num, varargs=int) <= Sig(Num, varargs=int) + assert Sig(Num, varargs=int) <= Sig(Num, varargs=Num) + assert Sig(Num, varargs=Num) <= Sig(Num, varargs=Num) + + assert not Sig(Num, varargs=int) <= Sig(int, varargs=int) + assert not Sig(int, varargs=Num) <= Sig(int, varargs=int) + assert not Sig(Num, varargs=Num) <= Sig(int, varargs=int) + assert not Sig(int, varargs=Num) <= Sig(Num, varargs=int) + assert not Sig(Num, varargs=Num) <= Sig(Num, varargs=int) + assert not Sig(Num, varargs=int) <= Sig(int, varargs=Num) + assert not Sig(Num, varargs=Num) <= Sig(int, varargs=Num) + + def test_comparison(): # Variable arguments shortcuts: assert not Sig(varargs=int) <= Sig()