Skip to content

Commit

Permalink
gh-129463: gh-128593: Simplify ForwardRef
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra committed Jan 30, 2025
1 parent a472244 commit 060ca69
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 147 deletions.
28 changes: 0 additions & 28 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ class Format(enum.IntEnum):
# preserved for compatibility with the old typing.ForwardRef class. The remaining
# names are private.
_SLOTS = (
"__forward_evaluated__",
"__forward_value__",
"__forward_is_argument__",
"__forward_is_class__",
"__forward_module__",
Expand Down Expand Up @@ -78,8 +76,6 @@ def __init__(
raise TypeError(f"Forward reference must be a string -- got {arg!r}")

self.__arg__ = arg
self.__forward_evaluated__ = False
self.__forward_value__ = None
self.__forward_is_argument__ = is_argument
self.__forward_is_class__ = is_class
self.__forward_module__ = module
Expand All @@ -97,16 +93,12 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
If the forward reference cannot be evaluated, raise an exception.
"""
if self.__forward_evaluated__:
return self.__forward_value__
if self.__cell__ is not None:
try:
value = self.__cell__.cell_contents
except ValueError:
pass
else:
self.__forward_evaluated__ = True
self.__forward_value__ = value
return value
if owner is None:
owner = self.__owner__
Expand Down Expand Up @@ -173,8 +165,6 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
else:
code = self.__forward_code__
value = eval(code, globals=globals, locals=locals)
self.__forward_evaluated__ = True
self.__forward_value__ = value
return value

def _evaluate(self, globalns, localns, type_params=_sentinel, *, recursive_guard):
Expand Down Expand Up @@ -229,22 +219,6 @@ def __forward_code__(self):
raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}")
return self.__code__

def __eq__(self, other):
if not isinstance(other, ForwardRef):
return NotImplemented
if self.__forward_evaluated__ and other.__forward_evaluated__:
return (
self.__forward_arg__ == other.__forward_arg__
and self.__forward_value__ == other.__forward_value__
)
return (
self.__forward_arg__ == other.__forward_arg__
and self.__forward_module__ == other.__forward_module__
)

def __hash__(self):
return hash((self.__forward_arg__, self.__forward_module__))

def __or__(self, other):
global _Union
if _Union is None:
Expand Down Expand Up @@ -284,8 +258,6 @@ def __init__(
# represent a single name).
assert isinstance(node, (ast.AST, str))
self.__arg__ = None
self.__forward_evaluated__ = False
self.__forward_value__ = None
self.__forward_is_argument__ = False
self.__forward_is_class__ = is_class
self.__forward_module__ = None
Expand Down
54 changes: 21 additions & 33 deletions Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def wrapper(a, b):
return wrapper


def assert_is_fwdref(case, obj, value):
case.assertIsInstance(obj, annotationlib.ForwardRef)
case.assertEqual(obj.__forward_arg__, value)


class MyClass:
def __repr__(self):
return "my repr"
Expand Down Expand Up @@ -59,8 +64,7 @@ def inner(arg: x):

anno = annotationlib.get_annotations(inner, format=Format.FORWARDREF)
fwdref = anno["arg"]
self.assertIsInstance(fwdref, annotationlib.ForwardRef)
self.assertEqual(fwdref.__forward_arg__, "x")
assert_is_fwdref(self, fwdref, "x")
with self.assertRaises(NameError):
fwdref.evaluate()

Expand All @@ -77,8 +81,7 @@ def f(x: int, y: doesntexist):
anno = annotationlib.get_annotations(f, format=Format.FORWARDREF)
self.assertIs(anno["x"], int)
fwdref = anno["y"]
self.assertIsInstance(fwdref, annotationlib.ForwardRef)
self.assertEqual(fwdref.__forward_arg__, "doesntexist")
assert_is_fwdref(self, fwdref, "doesntexist")
with self.assertRaises(NameError):
fwdref.evaluate()
self.assertEqual(fwdref.evaluate(globals={"doesntexist": 1}), 1)
Expand All @@ -96,28 +99,22 @@ def f(

anno = annotationlib.get_annotations(f, format=Format.FORWARDREF)
x_anno = anno["x"]
self.assertIsInstance(x_anno, ForwardRef)
self.assertEqual(x_anno, ForwardRef("some.module"))
assert_is_fwdref(self, x_anno, "some.module")

y_anno = anno["y"]
self.assertIsInstance(y_anno, ForwardRef)
self.assertEqual(y_anno, ForwardRef("some[module]"))
assert_is_fwdref(self, y_anno, "some[module]")

z_anno = anno["z"]
self.assertIsInstance(z_anno, ForwardRef)
self.assertEqual(z_anno, ForwardRef("some(module)"))
assert_is_fwdref(self, z_anno, "some(module)")

alpha_anno = anno["alpha"]
self.assertIsInstance(alpha_anno, ForwardRef)
self.assertEqual(alpha_anno, ForwardRef("some | obj"))
assert_is_fwdref(self, alpha_anno, "some | obj")

beta_anno = anno["beta"]
self.assertIsInstance(beta_anno, ForwardRef)
self.assertEqual(beta_anno, ForwardRef("+some"))
assert_is_fwdref(self, beta_anno, "+some")

gamma_anno = anno["gamma"]
self.assertIsInstance(gamma_anno, ForwardRef)
self.assertEqual(gamma_anno, ForwardRef("some < obj"))
assert_is_fwdref(self, gamma_anno, "some < obj")


class TestSourceFormat(unittest.TestCase):
Expand Down Expand Up @@ -362,13 +359,6 @@ def test_fwdref_to_builtin(self):
obj = object()
self.assertIs(ForwardRef("int").evaluate(globals={"int": obj}), obj)

def test_fwdref_value_is_cached(self):
fr = ForwardRef("hello")
with self.assertRaises(NameError):
fr.evaluate()
self.assertIs(fr.evaluate(globals={"hello": str}), str)
self.assertIs(fr.evaluate(), str)

def test_fwdref_with_owner(self):
self.assertEqual(
ForwardRef("Counter[int]", owner=collections).evaluate(),
Expand Down Expand Up @@ -457,12 +447,10 @@ def f2(a: undefined):
)
self.assertEqual(annotationlib.get_annotations(f1, format=1), {"a": int})

fwd = annotationlib.ForwardRef("undefined")
self.assertEqual(
annotationlib.get_annotations(f2, format=Format.FORWARDREF),
{"a": fwd},
)
self.assertEqual(annotationlib.get_annotations(f2, format=3), {"a": fwd})
for fmt in (Format.FORWARDREF, 3):
annos = annotationlib.get_annotations(f2, format=fmt)
self.assertEqual(list(annos), ["a"])
assert_is_fwdref(self, annos["a"], "undefined")

self.assertEqual(
annotationlib.get_annotations(f1, format=Format.STRING),
Expand Down Expand Up @@ -1012,10 +1000,10 @@ def evaluate(format, exc=NotImplementedError):

with self.assertRaises(NameError):
annotationlib.call_evaluate_function(evaluate, Format.VALUE)
self.assertEqual(
annotationlib.call_evaluate_function(evaluate, Format.FORWARDREF),
annotationlib.ForwardRef("undefined"),
)

fwdref = annotationlib.call_evaluate_function(evaluate, Format.FORWARDREF)
assert_is_fwdref(self, fwdref, "undefined")

self.assertEqual(
annotationlib.call_evaluate_function(evaluate, Format.STRING),
"undefined",
Expand Down
110 changes: 24 additions & 86 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ def wrapper(self):
return wrapper


class EqualToForwardRef:
def __init__(self, arg, module=None):
self.arg = arg
self.module = module

def __eq__(self, other):
if not isinstance(other, ForwardRef):
return NotImplemented
return self.arg == other.__forward_arg__ and self.module == other.__forward_module__


class Employee:
pass

Expand Down Expand Up @@ -467,8 +478,8 @@ def test_or(self):
self.assertEqual(X | "x", Union[X, "x"])
self.assertEqual("x" | X, Union["x", X])
# make sure the order is correct
self.assertEqual(get_args(X | "x"), (X, ForwardRef("x")))
self.assertEqual(get_args("x" | X), (ForwardRef("x"), X))
self.assertEqual(get_args(X | "x"), (X, EqualToForwardRef("x")))
self.assertEqual(get_args("x" | X), (EqualToForwardRef("x"), X))

def test_union_constrained(self):
A = TypeVar('A', str, bytes)
Expand Down Expand Up @@ -4965,7 +4976,7 @@ class C3:
def f(x: X): ...
self.assertEqual(
get_type_hints(f, globals(), locals()),
{'x': list[list[ForwardRef('X')]]}
{'x': list[list[EqualToForwardRef('X')]]}
)

def test_pep695_generic_class_with_future_annotations(self):
Expand Down Expand Up @@ -5183,12 +5194,15 @@ class Node(Generic[T]): ...
Callable[..., T], Callable[[int], int],
Tuple[Any, Any], Node[T], Node[int], Node[Any], typing.Iterable[T],
typing.Iterable[Any], typing.Iterable[int], typing.Dict[int, str],
typing.Dict[T, Any], ClassVar[int], ClassVar[List[T]], Tuple['T', 'T'],
Union['T', int], List['T'], typing.Mapping['T', int]]
typing.Dict[T, Any], ClassVar[int], ClassVar[List[T]]]
for t in things + [Any]:
self.assertEqual(t, copy(t))
self.assertEqual(t, deepcopy(t))

shallow_things = [Tuple['T', 'T'], Union['T', int], List['T'], typing.Mapping['T', int]]
for t in things + [Any]:
self.assertEqual(t, copy(t))

def test_immutability_by_copy_and_pickle(self):
# Special forms like Union, Any, etc., generic aliases to containers like List,
# Mapping, etc., and type variabcles are considered immutable by copy and pickle.
Expand Down Expand Up @@ -6087,82 +6101,6 @@ def test_forwardref_only_str_arg(self):
with self.assertRaises(TypeError):
typing.ForwardRef(1) # only `str` type is allowed

def test_forward_equality(self):
fr = typing.ForwardRef('int')
self.assertEqual(fr, typing.ForwardRef('int'))
self.assertNotEqual(List['int'], List[int])
self.assertNotEqual(fr, typing.ForwardRef('int', module=__name__))
frm = typing.ForwardRef('int', module=__name__)
self.assertEqual(frm, typing.ForwardRef('int', module=__name__))
self.assertNotEqual(frm, typing.ForwardRef('int', module='__other_name__'))

def test_forward_equality_gth(self):
c1 = typing.ForwardRef('C')
c1_gth = typing.ForwardRef('C')
c2 = typing.ForwardRef('C')
c2_gth = typing.ForwardRef('C')

class C:
pass
def foo(a: c1_gth, b: c2_gth):
pass

self.assertEqual(get_type_hints(foo, globals(), locals()), {'a': C, 'b': C})
self.assertEqual(c1, c2)
self.assertEqual(c1, c1_gth)
self.assertEqual(c1_gth, c2_gth)
self.assertEqual(List[c1], List[c1_gth])
self.assertNotEqual(List[c1], List[C])
self.assertNotEqual(List[c1_gth], List[C])
self.assertEqual(Union[c1, c1_gth], Union[c1])
self.assertEqual(Union[c1, c1_gth, int], Union[c1, int])

def test_forward_equality_hash(self):
c1 = typing.ForwardRef('int')
c1_gth = typing.ForwardRef('int')
c2 = typing.ForwardRef('int')
c2_gth = typing.ForwardRef('int')

def foo(a: c1_gth, b: c2_gth):
pass
get_type_hints(foo, globals(), locals())

self.assertEqual(hash(c1), hash(c2))
self.assertEqual(hash(c1_gth), hash(c2_gth))
self.assertEqual(hash(c1), hash(c1_gth))

c3 = typing.ForwardRef('int', module=__name__)
c4 = typing.ForwardRef('int', module='__other_name__')

self.assertNotEqual(hash(c3), hash(c1))
self.assertNotEqual(hash(c3), hash(c1_gth))
self.assertNotEqual(hash(c3), hash(c4))
self.assertEqual(hash(c3), hash(typing.ForwardRef('int', module=__name__)))

def test_forward_equality_namespace(self):
class A:
pass
def namespace1():
a = typing.ForwardRef('A')
def fun(x: a):
pass
get_type_hints(fun, globals(), locals())
return a

def namespace2():
a = typing.ForwardRef('A')

class A:
pass
def fun(x: a):
pass

get_type_hints(fun, globals(), locals())
return a

self.assertEqual(namespace1(), namespace1())
self.assertNotEqual(namespace1(), namespace2())

def test_forward_repr(self):
self.assertEqual(repr(List['int']), "typing.List[ForwardRef('int')]")
self.assertEqual(repr(List[ForwardRef('int', module='mod')]),
Expand Down Expand Up @@ -6226,7 +6164,7 @@ def cmp(o1, o2):
r1 = namespace1()
r2 = namespace2()
self.assertIsNot(r1, r2)
self.assertRaises(RecursionError, cmp, r1, r2)
self.assertNotEqual(r1, r2)

def test_union_forward_recursion(self):
ValueList = List['Value']
Expand Down Expand Up @@ -7146,7 +7084,7 @@ def func(x: undefined) -> undefined: ...
# FORWARDREF
self.assertEqual(
get_type_hints(func, format=annotationlib.Format.FORWARDREF),
{'x': ForwardRef('undefined'), 'return': ForwardRef('undefined')},
{'x': EqualToForwardRef('undefined'), 'return': EqualToForwardRef('undefined')},
)

# STRING
Expand Down Expand Up @@ -8030,7 +7968,7 @@ class Y(NamedTuple):
class Z(NamedTuple):
a: None
b: "str"
annos = {'a': type(None), 'b': ForwardRef("str")}
annos = {'a': type(None), 'b': EqualToForwardRef("str")}
self.assertEqual(Z.__annotations__, annos)
self.assertEqual(Z.__annotate__(annotationlib.Format.VALUE), annos)
self.assertEqual(Z.__annotate__(annotationlib.Format.FORWARDREF), annos)
Expand All @@ -8046,7 +7984,7 @@ class X(NamedTuple):
"""
ns = run_code(textwrap.dedent(code))
X = ns['X']
self.assertEqual(X.__annotations__, {'a': ForwardRef("int"), 'b': ForwardRef("None")})
self.assertEqual(X.__annotations__, {'a': EqualToForwardRef("int"), 'b': EqualToForwardRef("None")})

def test_deferred_annotations(self):
class X(NamedTuple):
Expand Down Expand Up @@ -9032,7 +8970,7 @@ class X(TypedDict):
class Y(TypedDict):
a: None
b: "int"
fwdref = ForwardRef('int', module=__name__)
fwdref = EqualToForwardRef('int', module=__name__)
self.assertEqual(Y.__annotations__, {'a': type(None), 'b': fwdref})
self.assertEqual(Y.__annotate__(annotationlib.Format.FORWARDREF), {'a': type(None), 'b': fwdref})

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
:class:`annotationlib.ForwardRef` objects no longer cache their value when
they are successfully evaluated. Successive calls to
:meth:`annotationlib.ForwardRef.evaluate` may return different values.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
:class:`annotationlib.ForwardRef` objects no longer compare or hash equal
when they refer to the same string. The implementation of equality was
error-prone because it did not take all attributes of the
:class:`!ForwardRef` object into account.

0 comments on commit 060ca69

Please sign in to comment.