Skip to content

Commit

Permalink
Improve the dispatch algorithm by avoiding unnecessary ambiguities (#151
Browse files Browse the repository at this point in the history
)

* Break ties for equal signatures

* Add a comment describing what is different from before

* Clarify

* Add example

* Remove `finally`

* Add RTD config

* Rename RTD config

* Fix RTD config

* Update tests/test_signature.py

Co-authored-by: Filippo Vicentini <[email protected]>

* Update tests/test_signature.py

Co-authored-by: Filippo Vicentini <[email protected]>

* Remove duplicate tests

* Add comment for remaining bug

* Reorganise tests

* Add a few more asserts

---------

Co-authored-by: Filippo Vicentini <[email protected]>
  • Loading branch information
wesselb and PhilipVinc authored Jun 2, 2024
1 parent 9854ee0 commit a5a1cd4
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 53 deletions.
82 changes: 61 additions & 21 deletions plum/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,6 @@ def expand_varargs(self, n: int) -> Tuple[TypeHint, ...]:
return self.types

def __le__(self, other) -> bool:
# If this signature has variable arguments, but the other does not, then this
# signature cannot be possibly smaller.
if self.has_varargs and not other.has_varargs:
return False

# If this signature and the other signature both have variable arguments, then
# the variable type of this signature must be less than the variable type of the
# other signature.
if (
self.has_varargs
and other.has_varargs
and not (
beartype.door.TypeHint(self.varargs)
<= beartype.door.TypeHint(other.varargs)
)
):
return False

# If the number of types of the signatures are unequal, then the signature
# with the fewer number of types must be expanded using variable arguments.
if not (
Expand All @@ -176,15 +158,73 @@ def __le__(self, other) -> bool:
):
return False

# Finally, expand the types and compare.
# Expand the types and compare. We implement the subset relationship, but, very
# importantly, deviate from the subset relationship in exactly one place.
self_types = self.expand_varargs(len(other.types))
other_types = other.expand_varargs(len(self.types))
return all(
if all(
[
beartype.door.TypeHint(x) == beartype.door.TypeHint(y)
for x, y in zip(self_types, other_types)
]
):
if self.has_varargs and other.has_varargs:
self_varargs = beartype.door.TypeHint(self.varargs)
other_varargs = beartype.door.TypeHint(other.varargs)
return self_varargs <= other_varargs

# Having variable arguments makes you slightly larger.
elif self.has_varargs:
return False
elif other.has_varargs:
return True

else:
return True

elif all(
[
beartype.door.TypeHint(x) <= beartype.door.TypeHint(y)
for x, y in zip(self_types, other_types)
]
)
):
# In this case, we have that `other >= self` is `False`, so returning `True`
# gives that `other < self` and returning `False` gives that `other` cannot
# be compared to `self`. Regardless of the return value, `other != self`.

if self.has_varargs and other.has_varargs:
# TODO: This implements the subset relationship. However, if the
# variable arguments are not used, then this may unnecessarily
# return `False`. For example, `(int, *A)` would not be
# comparable to `(Number, *B)`. However, if the argument given
# is `1.0`, then reasonably the variable arguments should be
# ignored and `(int, *A)` should be considered more specific
# than `(Number, *B)`.
self_varargs = beartype.door.TypeHint(self.varargs)
other_varargs = beartype.door.TypeHint(other.varargs)
return self_varargs <= other_varargs

elif self.has_varargs:
# Previously, this returned `False`, which would implement the subset
# relationship. We now deviate from the subset relationship! The
# rationale for this is as follows.
#
# A non-variable-arguments signature is compared to a variable-arguments
# signature only to determine which is more specific. At this point, the
# non-variable-arguments signature has number of types equal to the
# number of arguments given to the function, so any additional variable
# arguments are not necessary. Hence, we ignore the additional
# variable arguments in the comparison and return correctly `True`. For
# example, `(int, *int)` would be more specific than `(Number)`.
return True
elif other.has_varargs:
return True

else:
return True

else:
return False

def match(self, values) -> bool:
"""Check whether values match the signature.
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ source = ["plum"]
testpaths = ["tests/", "plum", "docs"]
addopts = [
"-ra",
"-p no:doctest",
"-p",
"no:doctest",
]
minversion = "6.0"

Expand Down
31 changes: 0 additions & 31 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys
import textwrap
import typing
from numbers import Number as Num

import pytest

Expand Down Expand Up @@ -243,33 +242,3 @@ def f(x):
assert r.resolve(m_c1.signature) == m_b1
m_b2.signature.precedence = 2
assert r.resolve(m_c1.signature) == m_b2


def test_117_case():
class A:
pass

class B:
pass

r = Resolver()

def f(x):
return x

m_a = Method(f, Signature(int, varargs=A))
r.register(m_a)
m_b = Method(f, Signature(int, varargs=B))
r.register(m_b)

with pytest.raises(AmbiguousLookupError):
r.resolve((1,))

r = Resolver()
m_a = Method(f, Signature(Num, varargs=int))
r.register(m_a)
m_b = Method(f, Signature(int, varargs=Num))
r.register(m_b)

with pytest.raises(AmbiguousLookupError):
r.resolve((1,))
177 changes: 177 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import pytest

from plum.dispatcher import Dispatcher
from plum.resolver import AmbiguousLookupError
from plum.signature import Signature as Sig
from plum.signature import append_default_args, inspect_signature
from plum.util import Missing
Expand Down Expand Up @@ -114,6 +116,181 @@ def test_expand_varargs():
assert s.expand_varargs(4) == (int, int, float, float)


def test_varargs_tie_breaking():
# These are related to bug #117.

assert Sig(int) < Sig(int, varargs=int)
assert Sig(int, varargs=int) < Sig(int, Num)
assert Sig(int, int, varargs=int) < Sig(int, Num)

assert not Sig(int) >= Sig(int, varargs=int)
assert not Sig(int, varargs=int) >= Sig(int, Num)
assert not Sig(int, int, varargs=int) >= Sig(int, Num)

dispatch = Dispatcher()

@dispatch
def f(*xs: int):
return "ints"

@dispatch
def f(*xs: Num):
return "nums"

@dispatch
def f(x: int):
return "int"

@dispatch
def f(x: int, y: int):
return "two ints"

@dispatch
def f(x: Num):
return "num"

@dispatch
def f(x: Num, y: Num):
return "two nums"

@dispatch
def f(x: int, *ys: int):
return "int and ints"

@dispatch
def f(x: int, *ys: Num):
return "int and nums"

@dispatch
def f(x: Num, *ys: int):
return "num and ints"

@dispatch
def f(x: Num, *ys: Num):
return "num and nums"

assert f(1) == "int"
assert f(1, 1) == "two ints"
assert f(1, 1, 1) == "int and ints"

assert f(1.0) == "num"
assert f(1.0, 1.0) == "two nums"
assert f(1.0, 1.0, 1.0) == "num and nums"

assert f(1, 1.0) == "int and nums"
assert f(1.0, 1) == "num and ints"

assert f(1, 1, 1.0) == "int and nums"
assert f(1.0, 1.0, 1) == "num and nums"
assert f(1, 1.0, 1.0) == "int and nums"
assert f(1.0, 1, 1) == "num and ints"


def test_117_case1():
dispatch = Dispatcher()

class A:
pass

class B:
pass

@dispatch
def f(x: int, *a: A):
return "int and As"

@dispatch
def f(x: int, *a: B):
return "int and Bs"

with pytest.raises(AmbiguousLookupError):
f(1)
assert f(1, A()) == "int and As"
assert f(1, B()) == "int and Bs"


@pytest.mark.xfail(reason="bug #117")
def test_117_case2():
dispatch = Dispatcher()

class A:
pass

class B:
pass

@dispatch
def f(x: int, *a: A):
return "int and As"

@dispatch
def f(x: Num, *a: B):
return "num and Bs"

assert f(1) == "int and As"
assert f(1, A()) == "int and As"
assert f(1.0) == "num and Bs"
assert f(1.0, B()) == "num and Bs"


def test_117_case3():
dispatch = Dispatcher()

class A:
pass

class B:
pass

@dispatch
def f(x: int, *a: A):
return "int and As"

@dispatch
def f(x: int, *a: B):
return "int and Bs"

@dispatch
def f(x: Num, *a: B):
return "num and Bs"

with pytest.raises(AmbiguousLookupError):
f(1)
assert f(1, A()) == "int and As"
assert f(1, B()) == "int and Bs"
assert f(1.0) == "num and Bs"
assert f(1.0, B()) == "num and Bs"


def test_varargs_subset():
assert Sig(int, varargs=int) == Sig(int, varargs=int)
assert Sig(int, varargs=int) < Sig(Num, varargs=int)
assert Sig(int, varargs=int) < Sig(int, varargs=Num)
assert Sig(int, varargs=int) < Sig(Num, varargs=Num)
assert Sig(int, varargs=Num) == Sig(int, varargs=Num)
assert Sig(int, varargs=Num) < Sig(Num, varargs=Num)
assert Sig(Num, varargs=int) == Sig(Num, varargs=int)
assert Sig(Num, varargs=int) < Sig(Num, varargs=Num)
assert Sig(Num, varargs=Num) == Sig(Num, varargs=Num)

assert not Sig(Num, varargs=int) <= Sig(int, varargs=int)
assert not Sig(int, varargs=Num) <= Sig(int, varargs=int)
assert not Sig(Num, varargs=Num) <= Sig(int, varargs=int)
assert not Sig(int, varargs=Num) <= Sig(Num, varargs=int)
assert not Sig(Num, varargs=Num) <= Sig(Num, varargs=int)
assert not Sig(Num, varargs=int) <= Sig(int, varargs=Num)
assert not Sig(Num, varargs=Num) <= Sig(int, varargs=Num)

class A:
pass

class B:
pass

assert not Sig(int, varargs=A) <= Sig(int, varargs=B)
assert not Sig(int, varargs=B) <= Sig(int, varargs=A)


def test_comparison():
# Variable arguments shortcuts:
assert not Sig(varargs=int) <= Sig()
Expand Down

0 comments on commit a5a1cd4

Please sign in to comment.